diff --git a/opt_einsum/backends/object_arrays.py b/opt_einsum/backends/object_arrays.py index 308cb67..eae0e92 100644 --- a/opt_einsum/backends/object_arrays.py +++ b/opt_einsum/backends/object_arrays.py @@ -7,8 +7,10 @@ import numpy as np +from opt_einsum.typing import ArrayType -def object_einsum(eq, *arrays): + +def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: """A ``einsum`` implementation for ``numpy`` arrays with object dtype. The loop is performed in python, meaning the objects themselves need only to implement ``__mul__`` and ``__add__`` for the contraction to be diff --git a/opt_einsum/backends/torch.py b/opt_einsum/backends/torch.py index ed92fd5..c3ae9b5 100644 --- a/opt_einsum/backends/torch.py +++ b/opt_einsum/backends/torch.py @@ -41,7 +41,7 @@ def transpose(a, axes): return a.permute(*axes) -def einsum(equation, *operands): +def einsum(equation, *operands, **kwargs): """Variadic version of torch.einsum to match numpy api.""" # rename symbols to support PyTorch 0.4.1 and earlier, # which allow only symbols a-z. diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index bc86b80..7270ce4 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -123,6 +123,26 @@ def _choose_memory_arg(memory_limit: _MemoryLimit, size_list: List[int]) -> Opti return int(memory_limit) +def _filter_einsum_defaults(kwargs: Dict[Literal["order", "casting", "dtype", "out"], Any]) -> Dict[str, Any]: + """Filters out default contract kwargs to pass to various backends.""" + kwargs = kwargs.copy() + ret = {} + if (order := kwargs.pop("order", "K")) != "K": + ret["order"] = order + + if (casting := kwargs.pop("casting", "safe")) != "safe": + ret["casting"] = casting + + if (dtype := kwargs.pop("dtype", None)) is not None: + ret["dtype"] = dtype + + if (out := kwargs.pop("out", None)) is not None: + ret["out"] = out + + ret.update(kwargs) + return ret + + @overload def contract_path( subscripts: str, @@ -330,7 +350,7 @@ def contract_path( path_tuple = [tuple(range(num_ops))] elif isinstance(optimize, paths.PathOptimizer): # Custom path optimizer supplied - path_tuple = optimize(input_sets, output_set, size_dict, memory_arg) # type: ignore + path_tuple = optimize(input_sets, output_set, size_dict, memory_arg) else: path_optimizer = paths.get_path_fn(optimize) path_tuple = path_optimizer(input_sets, output_set, size_dict, memory_arg) @@ -427,6 +447,7 @@ def _einsum(*operands: Any, **kwargs: Any) -> ArrayType: einsum_str = parser.convert_to_valid_einsum_chars(einsum_str) + kwargs = _filter_einsum_defaults(kwargs) return fn(einsum_str, *operands, **kwargs) @@ -906,7 +927,6 @@ def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType: return self._contract(ops, out=out, backend=backend, evaluate_constants=evaluate_constants) except ValueError as err: - raise original_msg = str(err.args) if err.args else "" msg = ( "Internal error while evaluating `ContractExpression`. Note that few checks are performed"