From 778a9bc507f0fb9f33fe01179a14a1d5c13fc15b Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Sat, 29 Jun 2024 16:13:26 -0400 Subject: [PATCH] Adds edge case tests with new shape kinds --- opt_einsum/backends/dispatch.py | 4 ++-- opt_einsum/contract.py | 4 ++-- opt_einsum/parser.py | 16 +++++++--------- opt_einsum/testing.py | 6 +++++- opt_einsum/tests/test_edge_cases.py | 19 +++++++++++++++++++ opt_einsum/tests/test_parser.py | 10 ++++------ opt_einsum/tests/test_paths.py | 6 ++++++ 7 files changed, 45 insertions(+), 20 deletions(-) diff --git a/opt_einsum/backends/dispatch.py b/opt_einsum/backends/dispatch.py index e1310f4..0abad45 100644 --- a/opt_einsum/backends/dispatch.py +++ b/opt_einsum/backends/dispatch.py @@ -5,7 +5,7 @@ """ import importlib -from typing import Any, Dict +from typing import Any, Dict, Tuple from opt_einsum.backends import cupy as _cupy from opt_einsum.backends import jax as _jax @@ -55,7 +55,7 @@ def _import_func(func: str, backend: str, default: Any = None) -> Any: # manually cache functions as python2 doesn't support functools.lru_cache # other libs will be added to this if needed, but pre-populate with numpy -_cached_funcs = { +_cached_funcs: Dict[Tuple[str, str], Any] = { ("einsum", "object"): object_arrays.object_einsum, } diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 43ba392..ee3c2db 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -305,7 +305,7 @@ def contract_path( if shapes: input_shapes = operands_prepped else: - input_shapes = [x.shape for x in operands_prepped] + input_shapes = [parser.get_shape(x) for x in operands_prepped] output_set = frozenset(output_subscript) indices = frozenset(input_subscripts.replace(",", "")) @@ -1066,7 +1066,7 @@ def contract_expression( ) if not isinstance(subscripts, str): - subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) # type: ignore + subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) kwargs["_gen_expression"] = True diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 05af3a6..15f7181 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -4,7 +4,7 @@ import itertools from collections.abc import Sequence -from typing import Any, Dict, Iterator, List, Tuple, Union +from typing import Any, Dict, Iterator, List, Tuple from opt_einsum.typing import ArrayType, TensorShapeType @@ -12,6 +12,7 @@ "is_valid_einsum_char", "has_valid_einsum_chars_only", "get_symbol", + "get_shape", "gen_unused_symbols", "convert_to_valid_einsum_chars", "alpha_canonicalize", @@ -172,7 +173,7 @@ def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: return tuple(max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0) for c in output) -_BaseTypes = Union[bool, int, float, complex, str, bytes] +_BaseTypes = (bool, int, float, complex, str, bytes) def get_shape(x: Any) -> TensorShapeType: @@ -254,7 +255,7 @@ def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str: return new_sub -def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, Tuple[ArrayType, ...]]: +def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, ...]]: """Convert 'interleaved' input to standard einsum input.""" tmp_operands = list(operands) operand_list = [] @@ -264,7 +265,6 @@ def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[s subscript_list.append(tmp_operands.pop(0)) output_list = tmp_operands[-1] if len(tmp_operands) else None - operands = [possibly_convert_to_numpy(x) for x in operand_list] # build a map from user symbols to single-character symbols based on `get_symbol` # The map retains the intrinsic order of user symbols @@ -289,7 +289,7 @@ def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[s subscripts += "->" subscripts += convert_subscripts(output_list, symbol_map) - return subscripts, tuple(operands) + return subscripts, tuple(operand_list) def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]: @@ -332,16 +332,14 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L "shapes is set to True but given at least one operand looks like an array" " (at least one operand has a shape attribute). " ) - operands = operands[1:] - else: - operands = [possibly_convert_to_numpy(x) for x in operands[1:]] + operands = operands[1:] else: subscripts, operands = convert_interleaved_input(operands) if shapes: operand_shapes = operands else: - operand_shapes = [o.shape for o in operands] + operand_shapes = [get_shape(o) for o in operands] # Check for proper "->" if ("-" in subscripts) or (">" in subscripts): diff --git a/opt_einsum/testing.py b/opt_einsum/testing.py index 2ed4b87..5c41bf9 100644 --- a/opt_einsum/testing.py +++ b/opt_einsum/testing.py @@ -2,6 +2,7 @@ Testing routines for opt_einsum. """ +import random from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload import pytest @@ -73,7 +74,10 @@ def build_views( views = [] for shape in build_shapes(string, dimension_dict=dimension_dict): - views.append(array_function(*shape)) + if shape: + views.append(array_function(*shape)) + else: + views.append(random.random()) return tuple(views) diff --git a/opt_einsum/tests/test_edge_cases.py b/opt_einsum/tests/test_edge_cases.py index 8f8111d..4835531 100644 --- a/opt_einsum/tests/test_edge_cases.py +++ b/opt_einsum/tests/test_edge_cases.py @@ -2,6 +2,8 @@ Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths """ +from typing import Any, Tuple + import pytest from opt_einsum import contract, contract_expression, contract_path @@ -131,3 +133,20 @@ def test_pathinfo_for_empty_contraction() -> None: # some info is built lazily, so check repr assert repr(info) assert info.largest_intermediate == 1 + + +@pytest.mark.parametrize( + "expression, operands", + [ + [",,->", (5, 5.0, 2.0j)], + ["ab,->", ([[5, 5], [2.0, 1]], 2.0j)], + ["ab,bc->ac", ([[5, 5], [2.0, 1]], [[2.0, 1], [3.0, 4]])], + ["ab,->", ([[5, 5], [2.0, 1]], True)], + ], +) +def test_contract_with_assumed_shapes(expression: str, operands: Tuple[Any]) -> None: + """Test that we can contract with assumed shapes, and that the output is correct. This is required as we need to infer intermediate shape sizes.""" + + benchmark = np.einsum(expression, *operands) + result = contract(expression, *operands, optimize=True) + assert np.allclose(benchmark, result) diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py index 291344d..6fcd3b2 100644 --- a/opt_einsum/tests/test_parser.py +++ b/opt_einsum/tests/test_parser.py @@ -6,7 +6,7 @@ import pytest -from opt_einsum.parser import get_symbol, parse_einsum_input, possibly_convert_to_numpy +from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input from opt_einsum.testing import build_arrays_from_tuples @@ -37,14 +37,12 @@ def test_parse_einsum_input_shapes_error() -> None: def test_parse_einsum_input_shapes() -> None: - np = pytest.importorskip("numpy") - eq = "ab,bc,cd" shapes = [(2, 3), (3, 4), (4, 5)] input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True) assert input_subscripts == eq assert output_subscript == "ad" - assert np.allclose([possibly_convert_to_numpy(shp) for shp in shapes], operands) + assert shapes == operands def test_parse_with_ellisis() -> None: @@ -63,7 +61,7 @@ def test_parse_with_ellisis() -> None: [[5, 5], (2,)], [(5, 5), (2,)], [[[[[[5, 2]]]]], (1, 1, 1, 1, 2)], - [[[[[["a", "b"]]]]], (1, 1, 1, 1, 2)], + [[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)], ["A", tuple()], [b"A", tuple()], [True, tuple()], @@ -73,4 +71,4 @@ def test_parse_with_ellisis() -> None: ], ) def test_get_shapes(array: Any, shape: Tuple[int]) -> None: - assert possibly_convert_to_numpy(array).shape == shape + assert get_shape(array) == shape diff --git a/opt_einsum/tests/test_paths.py b/opt_einsum/tests/test_paths.py index d7eae5c..70f0904 100644 --- a/opt_einsum/tests/test_paths.py +++ b/opt_einsum/tests/test_paths.py @@ -536,3 +536,9 @@ def custom_optimizer( path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") # type: ignore assert path == [(0, 1), (0, 1)] del oe.paths._PATH_OPTIONS["custom"] + + +def test_path_with_assumed_shapes() -> None: + + path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]]) + assert path == [(0, 1), (0, 1)]