Skip to content

Commit

Permalink
Adds edge case tests with new shape kinds
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Jun 29, 2024
1 parent 925b994 commit 778a9bc
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 20 deletions.
4 changes: 2 additions & 2 deletions opt_einsum/backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import importlib
from typing import Any, Dict
from typing import Any, Dict, Tuple

from opt_einsum.backends import cupy as _cupy
from opt_einsum.backends import jax as _jax
Expand Down Expand Up @@ -55,7 +55,7 @@ def _import_func(func: str, backend: str, default: Any = None) -> Any:

# manually cache functions as python2 doesn't support functools.lru_cache
# other libs will be added to this if needed, but pre-populate with numpy
_cached_funcs = {
_cached_funcs: Dict[Tuple[str, str], Any] = {
("einsum", "object"): object_arrays.object_einsum,
}

Expand Down
4 changes: 2 additions & 2 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def contract_path(
if shapes:
input_shapes = operands_prepped
else:
input_shapes = [x.shape for x in operands_prepped]
input_shapes = [parser.get_shape(x) for x in operands_prepped]
output_set = frozenset(output_subscript)
indices = frozenset(input_subscripts.replace(",", ""))

Expand Down Expand Up @@ -1066,7 +1066,7 @@ def contract_expression(
)

if not isinstance(subscripts, str):
subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) # type: ignore
subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes)

kwargs["_gen_expression"] = True

Expand Down
16 changes: 7 additions & 9 deletions opt_einsum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import itertools
from collections.abc import Sequence
from typing import Any, Dict, Iterator, List, Tuple, Union
from typing import Any, Dict, Iterator, List, Tuple

from opt_einsum.typing import ArrayType, TensorShapeType

__all__ = [
"is_valid_einsum_char",
"has_valid_einsum_chars_only",
"get_symbol",
"get_shape",
"gen_unused_symbols",
"convert_to_valid_einsum_chars",
"alpha_canonicalize",
Expand Down Expand Up @@ -172,7 +173,7 @@ def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output:
return tuple(max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0) for c in output)


_BaseTypes = Union[bool, int, float, complex, str, bytes]
_BaseTypes = (bool, int, float, complex, str, bytes)


def get_shape(x: Any) -> TensorShapeType:
Expand Down Expand Up @@ -254,7 +255,7 @@ def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
return new_sub


def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, Tuple[ArrayType, ...]]:
def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, ...]]:
"""Convert 'interleaved' input to standard einsum input."""
tmp_operands = list(operands)
operand_list = []
Expand All @@ -264,7 +265,6 @@ def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[s
subscript_list.append(tmp_operands.pop(0))

output_list = tmp_operands[-1] if len(tmp_operands) else None
operands = [possibly_convert_to_numpy(x) for x in operand_list]

# build a map from user symbols to single-character symbols based on `get_symbol`
# The map retains the intrinsic order of user symbols
Expand All @@ -289,7 +289,7 @@ def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[s
subscripts += "->"
subscripts += convert_subscripts(output_list, symbol_map)

return subscripts, tuple(operands)
return subscripts, tuple(operand_list)


def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
Expand Down Expand Up @@ -332,16 +332,14 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L
"shapes is set to True but given at least one operand looks like an array"
" (at least one operand has a shape attribute). "
)
operands = operands[1:]
else:
operands = [possibly_convert_to_numpy(x) for x in operands[1:]]
operands = operands[1:]
else:
subscripts, operands = convert_interleaved_input(operands)

if shapes:
operand_shapes = operands
else:
operand_shapes = [o.shape for o in operands]
operand_shapes = [get_shape(o) for o in operands]

# Check for proper "->"
if ("-" in subscripts) or (">" in subscripts):
Expand Down
6 changes: 5 additions & 1 deletion opt_einsum/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Testing routines for opt_einsum.
"""

import random
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload

import pytest
Expand Down Expand Up @@ -73,7 +74,10 @@ def build_views(

views = []
for shape in build_shapes(string, dimension_dict=dimension_dict):
views.append(array_function(*shape))
if shape:
views.append(array_function(*shape))
else:
views.append(random.random())
return tuple(views)


Expand Down
19 changes: 19 additions & 0 deletions opt_einsum/tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths
"""

from typing import Any, Tuple

import pytest

from opt_einsum import contract, contract_expression, contract_path
Expand Down Expand Up @@ -131,3 +133,20 @@ def test_pathinfo_for_empty_contraction() -> None:
# some info is built lazily, so check repr
assert repr(info)
assert info.largest_intermediate == 1


@pytest.mark.parametrize(
"expression, operands",
[
[",,->", (5, 5.0, 2.0j)],
["ab,->", ([[5, 5], [2.0, 1]], 2.0j)],
["ab,bc->ac", ([[5, 5], [2.0, 1]], [[2.0, 1], [3.0, 4]])],
["ab,->", ([[5, 5], [2.0, 1]], True)],
],
)
def test_contract_with_assumed_shapes(expression: str, operands: Tuple[Any]) -> None:
"""Test that we can contract with assumed shapes, and that the output is correct. This is required as we need to infer intermediate shape sizes."""

benchmark = np.einsum(expression, *operands)
result = contract(expression, *operands, optimize=True)
assert np.allclose(benchmark, result)
10 changes: 4 additions & 6 deletions opt_einsum/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from opt_einsum.parser import get_symbol, parse_einsum_input, possibly_convert_to_numpy
from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input
from opt_einsum.testing import build_arrays_from_tuples


Expand Down Expand Up @@ -37,14 +37,12 @@ def test_parse_einsum_input_shapes_error() -> None:


def test_parse_einsum_input_shapes() -> None:
np = pytest.importorskip("numpy")

eq = "ab,bc,cd"
shapes = [(2, 3), (3, 4), (4, 5)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
assert input_subscripts == eq
assert output_subscript == "ad"
assert np.allclose([possibly_convert_to_numpy(shp) for shp in shapes], operands)
assert shapes == operands


def test_parse_with_ellisis() -> None:
Expand All @@ -63,7 +61,7 @@ def test_parse_with_ellisis() -> None:
[[5, 5], (2,)],
[(5, 5), (2,)],
[[[[[[5, 2]]]]], (1, 1, 1, 1, 2)],
[[[[[["a", "b"]]]]], (1, 1, 1, 1, 2)],
[[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)],
["A", tuple()],
[b"A", tuple()],
[True, tuple()],
Expand All @@ -73,4 +71,4 @@ def test_parse_with_ellisis() -> None:
],
)
def test_get_shapes(array: Any, shape: Tuple[int]) -> None:
assert possibly_convert_to_numpy(array).shape == shape
assert get_shape(array) == shape
6 changes: 6 additions & 0 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,9 @@ def custom_optimizer(
path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") # type: ignore
assert path == [(0, 1), (0, 1)]
del oe.paths._PATH_OPTIONS["custom"]


def test_path_with_assumed_shapes() -> None:

path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]])
assert path == [(0, 1), (0, 1)]

0 comments on commit 778a9bc

Please sign in to comment.