diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 9bf351a5c..fc2435db9 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,11 +26,12 @@ THE SOFTWARE. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Never + +from typing_extensions import Self from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( Array, @@ -76,7 +77,7 @@ # {{{ NUserCollector -class NUserCollector(Mapper[None, []]): +class NUserCollector(Mapper[None, None, []]): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of times an array expression is a direct dependency of other nodes. @@ -317,7 +318,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]): +class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): """ Mapper to get the `direct predecessors @@ -413,9 +414,14 @@ class NodeCountMapper(CachedWalkMapper[[]]): Dictionary mapping node types to number of nodes of that type. """ - def __init__(self, count_duplicates: bool = False) -> None: + def __init__( + self, + count_duplicates: bool = False, + _visited_functions: set[Any] | None = None, + ) -> None: + super().__init__(_visited_functions=_visited_functions) + from collections import defaultdict - super().__init__() self.expr_type_counts: dict[type[Any], int] = defaultdict(int) self.count_duplicates = count_duplicates @@ -423,6 +429,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: # Returns unique nodes only if count_duplicates is False return id(expr) if self.count_duplicates else expr + def get_function_definition_cache_key( + self, expr: FunctionDefinition) -> int | FunctionDefinition: + # Returns unique nodes only if count_duplicates is False + return id(expr) if self.count_duplicates else expr + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + count_duplicates=self.count_duplicates, + _visited_functions=self._visited_functions) + def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 @@ -488,15 +504,20 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]): .. autoattribute:: expr_multiplicity_counts """ - def __init__(self) -> None: + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + from collections import defaultdict - super().__init__() self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: # Returns each node, including nodes that are duplicates return id(expr) + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + # Returns each node, including nodes that are duplicates + return id(expr) + def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_multiplicity_counts[expr] += 1 @@ -530,14 +551,16 @@ class CallSiteCountMapper(CachedWalkMapper[[]]): The number of nodes. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.count = 0 def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - @memoize_method + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return diff --git a/pytato/codegen.py b/pytato/codegen.py index d08445517..85ac4052d 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -68,9 +68,9 @@ if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Hashable, Mapping - from pytato.function import NamedCallResult + from pytato.function import FunctionDefinition, NamedCallResult from pytato.target import Target @@ -136,10 +136,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] ====================================== ===================================== """ - def __init__(self, target: Target, - kernels_seen: dict[str, lp.LoopKernel] | None = None - ) -> None: - super().__init__() + def __init__( + self, + target: Target, + kernels_seen: dict[str, lp.LoopKernel] | None = None, + _function_cache: dict[Hashable, FunctionDefinition] | None = None + ) -> None: + super().__init__(_function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target @@ -266,13 +269,16 @@ def normalize_outputs( @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NamesValidityChecker(CachedWalkMapper[[]]): - def __init__(self) -> None: + def __init__(self, _visited_functions: set[Any] | None = None) -> None: self.name_to_input: dict[str, InputArgumentBase] = {} - super().__init__() + super().__init__(_visited_functions=_visited_functions) def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def post_visit(self, expr: Any) -> None: if isinstance(expr, Placeholder | SizeParam | DataWrapper): if expr.name is not None: diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 4a0eb5897..216403191 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -70,6 +70,7 @@ TYPE_CHECKING, Any, Generic, + Never, TypeVar, cast, ) @@ -89,7 +90,13 @@ DistributedSendRefHolder, ) from pytato.scalar_expr import SCALAR_CLASSES -from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, CopyMapper +from pytato.transform import ( + ArrayOrNames, + CachedWalkMapper, + CombineMapper, + CopyMapper, + _verify_is_array, +) if TYPE_CHECKING: @@ -288,8 +295,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], + _function_cache: dict[Hashable, FunctionDefinition] | None = None, ) -> None: - super().__init__() + super().__init__(_function_cache=_function_cache) self.recvd_ary_to_name = recvd_ary_to_name self.sptpo_ary_to_name = sptpo_ary_to_name @@ -303,7 +311,7 @@ def clone_for_callee( self, function: FunctionDefinition) -> _DistributedInputReplacer: # Function definitions aren't allowed to contain receives, # stored arrays promoted to part outputs, or part outputs - return type(self)({}, {}, {}) + return type(self)({}, {}, {}, _function_cache=self._function_cache) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -394,7 +402,7 @@ def _make_distributed_partition( for name, val in name_to_part_output.items(): assert name not in name_to_output - name_to_output[name] = comm_replacer.rec_ary(val) + name_to_output[name] = _verify_is_array(comm_replacer.rec(val)) comm_ids = part_comm_ids[part_id] @@ -456,7 +464,7 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[frozenset[CommunicationOpIdentifier]]): + CombineMapper[frozenset[CommunicationOpIdentifier], Never]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index 5e8aa526d..bde9e2277 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -144,8 +144,8 @@ class MissingRecvError(DistributedPartitionVerificationError): @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class _SeenNodesWalkMapper(CachedWalkMapper[[]]): - def __init__(self) -> None: - super().__init__() + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.seen_nodes: set[ArrayOrNames] = set() def get_cache_key(self, expr: ArrayOrNames) -> int: diff --git a/pytato/equality.py b/pytato/equality.py index 1ef2a88ed..47bf7a0dc 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -27,8 +27,6 @@ from typing import TYPE_CHECKING, Any -from pytools import memoize_method - from pytato.array import ( AbstractResultWithNamedArrays, AdvancedIndexInContiguousAxes, @@ -49,13 +47,14 @@ SizeParam, Stack, ) +from pytato.function import FunctionDefinition if TYPE_CHECKING: from collections.abc import Callable from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder - from pytato.function import Call, FunctionDefinition, NamedCallResult + from pytato.function import Call, NamedCallResult from pytato.loopy import LoopyCall, LoopyCallResult __doc__ = """ @@ -85,26 +84,31 @@ class EqualityComparer: more on this. """ def __init__(self) -> None: + # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], bool] = {} - def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool: + def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool: cache_key = id(expr1), id(expr2) try: return self._cache[cache_key] except KeyError: - - method: Callable[[Array | AbstractResultWithNamedArrays, Any], - bool] - - try: - method = getattr(self, expr1._mapper_method) - except AttributeError: - if isinstance(expr1, Array): - result = self.handle_unsupported_array(expr1, expr2) + if expr1 is expr2: + result = True + elif isinstance(expr1, ArrayOrNames): + method: Callable[[ArrayOrNames, Any], bool] + try: + method = getattr(self, expr1._mapper_method) + except AttributeError: + if isinstance(expr1, Array): + result = self.handle_unsupported_array(expr1, expr2) + else: + result = self.map_foreign(expr1, expr2) else: - result = self.map_foreign(expr1, expr2) + result = method(expr1, expr2) + elif isinstance(expr1, FunctionDefinition): + result = self.map_function_definition(expr1, expr2) else: - result = (expr1 is expr2) or method(expr1, expr2) + result = self.map_foreign(expr1, expr2) self._cache[cache_key] = result return result @@ -296,7 +300,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.tags == expr2.tags ) - @memoize_method def map_function_definition(self, expr1: FunctionDefinition, expr2: Any ) -> bool: return (expr1.__class__ is expr2.__class__ @@ -310,7 +313,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any def map_call(self, expr1: Call, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ - and self.map_function_definition(expr1.function, expr2.function) + and self.rec(expr1.function, expr2.function) and frozenset(expr1.bindings) == frozenset(expr2.bindings) and all(self.rec(bnd, expr2.bindings[name]) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index f2e7f066b..17c915778 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -31,8 +31,6 @@ import numpy as np from immutabledict import immutabledict -from pytools import memoize_method - from pytato.array import ( Array, Axis, @@ -41,7 +39,7 @@ IndexLambda, ReductionDescriptor, ) -from pytato.transform import Mapper +from pytato.transform import ForeignObjectError, Mapper if TYPE_CHECKING: @@ -58,7 +56,7 @@ # {{{ Reprifier -class Reprifier(Mapper[str, [int]]): +class Reprifier(Mapper[str, str, [int]]): """ Stringifies :mod:`pytato`-types to closely resemble CPython's implementation of :func:`repr` for its builtin datatypes. @@ -71,6 +69,7 @@ def __init__(self, self.truncation_depth = truncation_depth self.truncation_string = truncation_string + # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], str] = {} def rec(self, expr: Any, depth: int) -> str: @@ -78,7 +77,19 @@ def rec(self, expr: Any, depth: int) -> str: try: return self._cache[cache_key] except KeyError: - result = super().rec(expr, depth) + try: + result = super().rec(expr, depth) + except ForeignObjectError: + result = self.map_foreign(expr, depth) + self._cache[cache_key] = result + return result + + def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str: + cache_key = (id(expr), depth) + try: + return self._cache[cache_key] + except KeyError: + result = super().rec_function_definition(expr, depth) self._cache[cache_key] = result return result @@ -171,7 +182,6 @@ def _get_field_val(field: str) -> str: for field in dataclasses.fields(type(expr))) + ")") - @memoize_method def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string @@ -194,7 +204,7 @@ def map_call(self, expr: Call, depth: int) -> str: def _get_field_val(field: str) -> str: if field == "function": - return self.map_function_definition(expr.function, depth+1) + return self.rec_function_definition(expr.function, depth+1) else: return self.rec(getattr(expr, field), depth+1) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 528de5d0e..c4b029a33 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -28,7 +28,7 @@ import sys from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Never import islpy as isl @@ -384,7 +384,7 @@ def update_t_unit(self, t_unit: lp.TranslationUnit) -> None: # {{{ codegen mapper -class CodeGenMapper(Mapper[ImplementedResult, [CodeGenState]]): +class CodeGenMapper(Mapper[ImplementedResult, Never, [CodeGenState]]): """A mapper for generating code for nodes in the computation graph. """ exprgen_mapper: InlinedExpressionGenMapper diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 0c9bd413c..f8bb080be 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -30,6 +30,7 @@ import sys from typing import ( TYPE_CHECKING, + Never, TypedDict, TypeVar, cast, @@ -171,7 +172,7 @@ def _is_slice_trivial(slice_: NormalizedSlice, } -class NumpyCodegenMapper(CachedMapper[str, []]): +class NumpyCodegenMapper(CachedMapper[str, Never, []]): """ .. note:: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 393583037..40cb2b6e2 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -28,10 +28,12 @@ """ import dataclasses import logging +from collections.abc import Hashable from typing import ( TYPE_CHECKING, Any, Generic, + Never, ParamSpec, TypeAlias, TypeVar, @@ -43,7 +45,6 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -80,15 +81,12 @@ if TYPE_CHECKING: - from collections.abc import Callable, Hashable, Iterable, Mapping + from collections.abc import Callable, Iterable, Mapping ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) -TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper - Array, AbstractResultWithNamedArrays, ArrayOrNames) -CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = frozenset[Array] @@ -142,13 +140,15 @@ A type variable representing the input type of a :class:`Mapper`. -.. class:: CombineT +.. class:: ResultT - A type variable representing the type of a :class:`CombineMapper`. + A type variable representing the result type of a :class:`Mapper` when mapping + a :class:`pytato.Array` or :class:`pytato.AbstractResultWithNamedArrays`. -.. class:: ResultT +.. class:: FunctionResultT - A type variable representing the result type of a :class:`Mapper`. + A type variable representing the result type of a :class:`Mapper` when mapping + a :class:`pytato.function.FunctionDefinition`. .. class:: Scalar @@ -162,13 +162,23 @@ class UnsupportedArrayError(ValueError): pass +class ForeignObjectError(ValueError): + pass + + # {{{ mapper base class ResultT = TypeVar("ResultT") +FunctionResultT = TypeVar("FunctionResultT") P = ParamSpec("P") -class Mapper(Generic[ResultT, P]): +def _verify_is_array(expr: ArrayOrNames) -> Array: + assert isinstance(expr, Array) + return expr + + +class Mapper(Generic[ResultT, FunctionResultT, P]): """A class that when called with a :class:`pytato.Array` recursively iterates over the DAG, calling the *_mapper_method* of each node. Users of this class are expected to override the methods of this class or create a @@ -180,12 +190,11 @@ class Mapper(Generic[ResultT, P]): if this is not desired. .. automethod:: handle_unsupported_array - .. automethod:: map_foreign .. automethod:: rec .. automethod:: __call__ """ - def handle_unsupported_array(self, expr: MappedT, + def handle_unsupported_array(self, expr: Array, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Mapper method that is invoked for :class:`pytato.Array` subclasses for which a mapper @@ -194,16 +203,9 @@ def handle_unsupported_array(self, expr: MappedT, raise UnsupportedArrayError( f"{type(self).__name__} cannot handle expressions of type {type(expr)}") - def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> ResultT: - """Mapper method that is invoked for an object of class for which a - mapper method does not exist in this mapper. - """ - raise ValueError( - f"{type(self).__name__} encountered invalid foreign object: {expr!r}") - def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Call the mapper method of *expr* and return the result.""" - method: Callable[..., Array] | None + method: Callable[..., ResultT] | None try: method = getattr(self, expr._mapper_method) @@ -218,11 +220,28 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: else: return self.handle_unsupported_array(expr, *args, **kwargs) else: - return self.map_foreign(expr, *args, **kwargs) + raise ForeignObjectError( + f"{type(self).__name__} encountered invalid foreign " + f"object: {expr!r}") from None assert method is not None return cast("ResultT", method(expr, *args, **kwargs)) + def rec_function_definition( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> FunctionResultT: + """Call the mapper method of *expr* and return the result.""" + method: Callable[..., FunctionResultT] | None + + try: + method = self.map_function_definition # type: ignore[attr-defined] + except AttributeError: + raise ValueError( + f"{type(self).__name__} lacks a mapper method for functions.") from None + + assert method is not None + return method(expr, *args, **kwargs) + def __call__(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Handle the mapping of *expr*.""" @@ -233,23 +252,38 @@ def __call__(self, # {{{ CachedMapper -class CachedMapper(Mapper[ResultT, P]): +class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """Mapper class that maps each node in the DAG exactly once. This loses some information compared to :class:`Mapper` as a node is visited only from one of its predecessors. .. automethod:: get_cache_key + .. automethod:: get_function_definition_cache_key + .. automethod:: clone_for_callee """ - def __init__(self) -> None: + def __init__( + self, + # Arrays are cached separately for each call stack frame, but + # functions are cached globally + _function_cache: dict[Hashable, FunctionResultT] | None = None + ) -> None: super().__init__() self._cache: dict[Hashable, ResultT] = {} + self._function_cache: dict[Hashable, FunctionResultT] = \ + _function_cache if _function_cache is not None else {} + def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs ) -> Hashable: return (expr, *args, tuple(sorted(kwargs.items()))) + def get_function_definition_cache_key( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> Hashable: + return (expr, *args, tuple(sorted(kwargs.items()))) + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: key = self.get_cache_key(expr, *args, **kwargs) try: @@ -259,12 +293,31 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: self._cache[key] = result return result + def rec_function_definition( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> FunctionResultT: + key = self.get_function_definition_cache_key(expr, *args, **kwargs) + try: + return self._function_cache[key] + except KeyError: + result = super().rec_function_definition(expr, *args, **kwargs) + self._function_cache[key] = result + return result + + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + return type(self)(_function_cache=self._function_cache) + # }}} # {{{ TransformMapper -class TransformMapper(CachedMapper[ArrayOrNames, []]): +class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """Base class for mappers that transform :class:`pytato.array.Array`\\ s into other :class:`pytato.array.Array`\\ s. @@ -272,19 +325,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, []]): arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. - .. automethod:: clone_for_callee """ - def rec_ary(self, expr: Array) -> Array: - res = self.rec(expr) - assert isinstance(res, Array) - return res - - def clone_for_callee(self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - return type(self)() + pass # }}} @@ -292,8 +334,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: # {{{ TransformMapperWithExtraArgs class TransformMapperWithExtraArgs( - CachedMapper[ArrayOrNames, P], - Mapper[ArrayOrNames, P] + CachedMapper[ArrayOrNames, FunctionDefinition, P] ): """ Similar to :class:`TransformMapper`, but each mapper method takes extra @@ -301,20 +342,8 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. - - .. automethod:: clone_for_callee """ - def rec_ary(self, expr: Array, *args: P.args, **kwargs: P.kwargs) -> Array: - res = self.rec(expr, *args, **kwargs) - assert isinstance(res, Array) - return res - - def clone_for_callee(self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - return type(self)() + pass # }}} @@ -360,18 +389,18 @@ def map_placeholder(self, expr: Placeholder) -> Array: non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> Array: - arrays = tuple(self.rec_ary(arr) for arr in expr.arrays) + arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: - arrays = tuple(self.rec_ary(arr) for arr in expr.arrays) + arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) return Concatenate(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll) -> Array: - return Roll(array=self.rec_ary(expr.array), + return Roll(array=_verify_is_array(self.rec(expr.array)), shift=expr.shift, axis=expr.axis, axes=expr.axes, @@ -379,14 +408,14 @@ def map_roll(self, expr: Roll) -> Array: non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: - return AxisPermutation(array=self.rec_ary(expr.array), + return AxisPermutation(array=_verify_is_array(self.rec(expr.array)), axis_permutation=expr.axis_permutation, axes=expr.axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: - return type(expr)(self.rec_ary(expr.array), + return type(expr)(_verify_is_array(self.rec(expr.array)), indices=self.rec_idx_or_size_tuple(expr.indices), axes=expr.axes, tags=expr.tags, @@ -423,7 +452,7 @@ def map_size_param(self, expr: SizeParam) -> Array: def map_einsum(self, expr: Einsum) -> Array: return Einsum(expr.access_descriptors, - tuple(self.rec_ary(arg) for arg in expr.args), + tuple(_verify_is_array(self.rec(arg)) for arg in expr.args), axes=expr.axes, redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, @@ -440,7 +469,7 @@ def map_named_array(self, expr: NamedArray) -> Array: def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: - return DictOfNamedArrays({key: self.rec_ary(val.expr) + return DictOfNamedArrays({key: _verify_is_array(self.rec(val.expr)) for key, val in expr.items()}, tags=expr.tags ) @@ -468,7 +497,7 @@ def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: - return Reshape(self.rec_ary(expr.array), + return Reshape(_verify_is_array(self.rec(expr.array)), newshape=self.rec_idx_or_size_tuple(expr.newshape), order=expr.order, axes=expr.axes, @@ -479,10 +508,10 @@ def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: return DistributedSendRefHolder( send=DistributedSend( - data=self.rec_ary(expr.send.data), + data=_verify_is_array(self.rec(expr.send.data)), dest_rank=expr.send.dest_rank, comm_tag=expr.send.comm_tag), - passthrough_data=self.rec_ary(expr.passthrough_data), + passthrough_data=_verify_is_array(self.rec(expr.passthrough_data)), ) def map_distributed_recv(self, expr: DistributedRecv) -> Array: @@ -492,7 +521,6 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array: dtype=expr.dtype, tags=expr.tags, axes=expr.axes, non_equality_tags=expr.non_equality_tags) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: # spawn a new mapper to avoid unsound cache hits, since the namespace of the @@ -503,7 +531,7 @@ def map_function_definition(self, return dataclasses.replace(expr, returns=immutabledict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return Call(self.map_function_definition(expr.function), + return Call(self.rec_function_definition(expr.function), immutabledict({name: self.rec(bnd) for name, bnd in expr.bindings.items()}), tags=expr.tags, @@ -561,19 +589,21 @@ def map_placeholder(self, non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple(self.rec_ary(arr, *args, **kwargs) for arr in expr.arrays) + arrays = tuple( + _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple(self.rec_ary(arr, *args, **kwargs) for arr in expr.arrays) + arrays = tuple( + _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) return Concatenate(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array: - return Roll(array=self.rec_ary(expr.array, *args, **kwargs), + return Roll(array=_verify_is_array(self.rec(expr.array, *args, **kwargs)), shift=expr.shift, axis=expr.axis, axes=expr.axes, @@ -582,7 +612,8 @@ def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array: def map_axis_permutation(self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> Array: - return AxisPermutation(array=self.rec_ary(expr.array, *args, **kwargs), + return AxisPermutation(array=_verify_is_array( + self.rec(expr.array, *args, **kwargs)), axis_permutation=expr.axis_permutation, axes=expr.axes, tags=expr.tags, @@ -591,7 +622,7 @@ def map_axis_permutation(self, expr: AxisPermutation, def _map_index_base(self, expr: IndexBase, *args: P.args, **kwargs: P.kwargs) -> Array: assert isinstance(expr, _SuppliedAxesAndTagsMixin) - return type(expr)(self.rec_ary(expr.array, *args, **kwargs), + return type(expr)(_verify_is_array(self.rec(expr.array, *args, **kwargs)), indices=self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs), axes=expr.axes, @@ -631,7 +662,8 @@ def map_size_param(self, def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> Array: return Einsum(expr.access_descriptors, - tuple(self.rec_ary(arg, *args, **kwargs) for arg in expr.args), + tuple(_verify_is_array( + self.rec(arg, *args, **kwargs)) for arg in expr.args), axes=expr.axes, redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, @@ -650,7 +682,8 @@ def map_named_array(self, def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs ) -> DictOfNamedArrays: - return DictOfNamedArrays({key: self.rec_ary(val.expr, *args, **kwargs) + return DictOfNamedArrays({key: _verify_is_array( + self.rec(val.expr, *args, **kwargs)) for key, val in expr.items()}, tags=expr.tags, ) @@ -682,7 +715,7 @@ def map_loopy_call_result(self, expr: LoopyCallResult, def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> Array: - return Reshape(self.rec_ary(expr.array, *args, **kwargs), + return Reshape(_verify_is_array(self.rec(expr.array, *args, **kwargs)), newshape=self.rec_idx_or_size_tuple(expr.newshape, *args, **kwargs), order=expr.order, @@ -694,10 +727,11 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, *args: P.args, **kwargs: P.kwargs) -> Array: return DistributedSendRefHolder( send=DistributedSend( - data=self.rec_ary(expr.send.data, *args, **kwargs), + data=_verify_is_array(self.rec(expr.send.data, *args, **kwargs)), dest_rank=expr.send.dest_rank, comm_tag=expr.send.comm_tag), - passthrough_data=self.rec_ary(expr.passthrough_data, *args, **kwargs)) + passthrough_data=_verify_is_array( + self.rec(expr.passthrough_data, *args, **kwargs))) def map_distributed_recv(self, expr: DistributedRecv, *args: P.args, **kwargs: P.kwargs) -> Array: @@ -717,7 +751,7 @@ def map_function_definition( def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> AbstractResultWithNamedArrays: - return Call(self.map_function_definition(expr.function, *args, **kwargs), + return Call(self.rec_function_definition(expr.function, *args, **kwargs), immutabledict({name: self.rec(bnd, *args, **kwargs) for name, bnd in expr.bindings.items()}), tags=expr.tags, @@ -734,16 +768,21 @@ def map_named_call_result(self, expr: NamedCallResult, # {{{ CombineMapper -class CombineMapper(Mapper[ResultT, []]): +class CombineMapper(Mapper[ResultT, FunctionResultT, []]): """ Abstract mapper that recursively combines the results of user nodes of a given expression. .. automethod:: combine """ - def __init__(self) -> None: + def __init__( + self, + _function_cache: dict[FunctionDefinition, FunctionResultT] | None = None + ) -> None: super().__init__() self.cache: dict[ArrayOrNames, ResultT] = {} + self.function_cache: dict[FunctionDefinition, FunctionResultT] = \ + _function_cache if _function_cache is not None else {} def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[ResultT, ...]: @@ -756,6 +795,14 @@ def rec(self, expr: ArrayOrNames) -> ResultT: self.cache[expr] = result return result + def rec_function_definition( + self, expr: FunctionDefinition) -> FunctionResultT: + if expr in self.function_cache: + return self.function_cache[expr] + result: FunctionResultT = super().rec_function_definition(expr) + self.function_cache[expr] = result + return result + def __call__(self, expr: ArrayOrNames) -> ResultT: return self.rec(expr) @@ -839,8 +886,7 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> ResultT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) - @memoize_method - def map_function_definition(self, expr: FunctionDefinition) -> ResultT: + def map_function_definition(self, expr: FunctionDefinition) -> FunctionResultT: raise NotImplementedError("Combining results from a callee expression" " is context-dependent. Derived classes" " must override map_function_definition.") @@ -858,7 +904,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper -class DependencyMapper(CombineMapper[R]): +class DependencyMapper(CombineMapper[R, R]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of :class:`pytato.array.Array`'s it depends on. @@ -920,14 +966,13 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. return frozenset() def map_call(self, expr: Call) -> R: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[self.rec(bnd) for bnd in expr.bindings.values()]) def map_named_call_result(self, expr: NamedCallResult) -> R: @@ -958,7 +1003,8 @@ def combine(self, *args: frozenset[Array]) -> frozenset[Array]: # {{{ InputGatherer -class InputGatherer(CombineMapper[frozenset[InputArgumentBase]]): +class InputGatherer( + CombineMapper[frozenset[InputArgumentBase], frozenset[InputArgumentBase]]): """ Mapper to combine all instances of :class:`pytato.array.InputArgumentBase` that an array expression depends on. @@ -977,7 +1023,6 @@ def map_data_wrapper(self, expr: DataWrapper) -> frozenset[InputArgumentBase]: def map_size_param(self, expr: SizeParam) -> frozenset[SizeParam]: return frozenset([expr]) - @memoize_method def map_function_definition(self, expr: FunctionDefinition ) -> frozenset[InputArgumentBase]: # get rid of placeholders local to the function. @@ -999,7 +1044,7 @@ def map_function_definition(self, expr: FunctionDefinition return frozenset(result) def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[ self.rec(bnd) for name, bnd in sorted(expr.bindings.items())]) @@ -1009,7 +1054,8 @@ def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: # {{{ SizeParamGatherer -class SizeParamGatherer(CombineMapper[frozenset[SizeParam]]): +class SizeParamGatherer( + CombineMapper[frozenset[SizeParam], frozenset[SizeParam]]): """ Mapper to combine all instances of :class:`pytato.array.SizeParam` that an array expression depends on. @@ -1022,14 +1068,13 @@ def combine(self, *args: frozenset[SizeParam] def map_size_param(self, expr: SizeParam) -> frozenset[SizeParam]: return frozenset([expr]) - @memoize_method def map_function_definition(self, expr: FunctionDefinition ) -> frozenset[SizeParam]: return self.combine(*[self.rec(ret) for ret in expr.returns.values()]) def map_call(self, expr: Call) -> frozenset[SizeParam]: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[ self.rec(bnd) for name, bnd in sorted(expr.bindings.items())]) @@ -1039,7 +1084,7 @@ def map_call(self, expr: Call) -> frozenset[SizeParam]: # {{{ WalkMapper -class WalkMapper(Mapper[None, P]): +class WalkMapper(Mapper[None, None, P]): """ A mapper that walks over all the arrays in a :class:`pytato.Array`. @@ -1228,7 +1273,7 @@ def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return - self.map_function_definition(expr.function, *args, **kwargs) + self.rec_function_definition(expr.function, *args, **kwargs) for bnd in expr.bindings.values(): self.rec(bnd, *args, **kwargs) @@ -1248,6 +1293,9 @@ def map_named_call_result(self, expr: NamedCallResult, # {{{ CachedWalkMapper +VisitKeyT: TypeAlias = Hashable + + class CachedWalkMapper(WalkMapper[P]): """ WalkMapper that visits each node in the DAG exactly once. This loses some @@ -1255,24 +1303,45 @@ class CachedWalkMapper(WalkMapper[P]): one of its predecessors. """ - def __init__(self) -> None: + def __init__( + self, + _visited_functions: set[VisitKeyT] | None = None + ) -> None: super().__init__() - self._visited_nodes: set[Any] = set() + self._visited_arrays_or_names: set[VisitKeyT] = set() + + self._visited_functions: set[VisitKeyT] = \ + _visited_functions if _visited_functions is not None else set() - def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: + def get_cache_key( + self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs + ) -> VisitKeyT: raise NotImplementedError - def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any - ) -> None: + def get_function_definition_cache_key( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> VisitKeyT: + raise NotImplementedError + + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> None: cache_key = self.get_cache_key(expr, *args, **kwargs) - if cache_key in self._visited_nodes: + if cache_key in self._visited_arrays_or_names: return super().rec(expr, *args, **kwargs) - self._visited_nodes.add(cache_key) + self._visited_arrays_or_names.add(cache_key) + + def rec_function_definition(self, expr: FunctionDefinition, + *args: P.args, **kwargs: P.kwargs) -> None: + cache_key = self.get_function_definition_cache_key(expr, *args, **kwargs) + if cache_key in self._visited_functions: + return + + super().rec_function_definition(expr, *args, **kwargs) + self._visited_functions.add(cache_key) def clone_for_callee(self, function: FunctionDefinition) -> Self: - return type(self)() + return type(self)(_visited_functions=self._visited_functions) # }}} @@ -1291,8 +1360,10 @@ class TopoSortMapper(CachedWalkMapper[[]]): :class:`~pytato.function.FunctionDefinition`. """ - def __init__(self) -> None: - super().__init__() + def __init__( + self, + _visited_functions: set[VisitKeyT] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.topological_order: list[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: @@ -1301,7 +1372,6 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. return @@ -1317,13 +1387,17 @@ class CachedMapAndCopyMapper(CopyMapper): traversals are memoized i.e. each node is mapped via *map_fn* exactly once. """ - def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: - super().__init__() + def __init__( + self, + map_fn: Callable[[ArrayOrNames], ArrayOrNames], + _function_cache: dict[Hashable, FunctionDefinition] | None = None + ) -> None: + super().__init__(_function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn def clone_for_callee( self, function: FunctionDefinition) -> Self: - return type(self)(self.map_fn) + return type(self)(self.map_fn, _function_cache=self._function_cache) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: if expr in self._cache: @@ -1371,7 +1445,7 @@ def _materialize_if_mpms(expr: Array, return MPMSMaterializerAccumulator(materialized_predecessors, expr) -class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, []]): +class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): """ See :func:`materialize_with_mpms` for an explanation. @@ -1556,7 +1630,7 @@ def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, if not source_dict: data = {} else: - data = {name: copy_mapper.rec_ary(val.expr) + data = {name: _verify_is_array(copy_mapper.rec(val.expr)) for name, val in sorted(source_dict.items())} return DictOfNamedArrays(data, tags=source_dict.tags) @@ -1648,7 +1722,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: # {{{ UsersCollector -class UsersCollector(CachedMapper[None, []]): +class UsersCollector(CachedMapper[None, Never, []]): """ Maps a graph to a dictionary representation mapping a node to its users, i.e. all the nodes using its value. @@ -1772,7 +1846,6 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> None: self.rec_idx_or_size_tuple(expr, expr.shape) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: raise AssertionError("Control shouldn't reach at this point." " Instantiate another UsersCollector to" diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index bc3d69909..34f89cbc1 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -40,7 +40,7 @@ ) from pytato.function import Call, NamedCallResult from pytato.tags import InlineCallTag -from pytato.transform import ArrayOrNames, CopyMapper +from pytato.transform import ArrayOrNames, CopyMapper, _verify_is_array if TYPE_CHECKING: @@ -58,12 +58,17 @@ class PlaceholderSubstitutor(CopyMapper): """ def __init__(self, substitutions: Mapping[str, Array]) -> None: + # Ignoring function cache, since we don't support functions anyway super().__init__() self.substitutions = substitutions def map_placeholder(self, expr: Placeholder) -> Array: return self.substitutions[expr.name] + def map_named_call_result(self, expr: NamedCallResult) -> NamedCallResult: + raise NotImplementedError( + "PlaceholderSubstitutor does not support functions.") + class Inliner(CopyMapper): """ @@ -78,7 +83,7 @@ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: substitutor = PlaceholderSubstitutor(new_expr.bindings) return DictOfNamedArrays( - {name: substitutor.rec_ary(ret) + {name: _verify_is_array(substitutor.rec(ret)) for name, ret in new_expr.function.returns.items()}, tags=new_expr.tags ) diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 7a23518c6..8cd635f61 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -59,6 +59,7 @@ from pytato.transform import ( MappedT, TransformMapperWithExtraArgs, + _verify_is_array, ) from pytato.utils import are_shapes_equal @@ -186,30 +187,34 @@ def map_index_lambda(self, and isinstance(hlo.x2, Array) and are_shapes_equal(hlo.x1.shape, hlo.x2.shape)) # https://github.com/python/mypy/issues/16499 - return self.rec_ary(hlo.x1, ctx) + self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return] + return ( # type: ignore[no-any-return] + _verify_is_array(self.rec(hlo.x1, ctx)) + + _verify_is_array(self.rec(hlo.x2, ctx))) elif hlo.binary_op == BinaryOpType.SUB: assert (isinstance(hlo.x1, Array) and isinstance(hlo.x2, Array) and are_shapes_equal(hlo.x1.shape, hlo.x2.shape)) assert are_shapes_equal(hlo.x1.shape, hlo.x2.shape) # https://github.com/python/mypy/issues/16499 - return self.rec_ary(hlo.x1, ctx) - self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return] + return ( # type: ignore[no-any-return] + _verify_is_array(self.rec(hlo.x1, ctx)) + - _verify_is_array(self.rec(hlo.x2, ctx))) elif hlo.binary_op == BinaryOpType.MULT: if isinstance(hlo.x1, Array) and np.isscalar(hlo.x2): # https://github.com/python/mypy/issues/16499 - return self.rec_ary(hlo.x1, ctx) * hlo.x2 # type: ignore[no-any-return] + return _verify_is_array(self.rec(hlo.x1, ctx)) * hlo.x2 # type: ignore[no-any-return] else: assert isinstance(hlo.x2, Array) and np.isscalar(hlo.x1) # https://github.com/python/mypy/issues/16499 - return hlo.x1 * self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return] + return hlo.x1 * _verify_is_array(self.rec(hlo.x2, ctx)) # type: ignore[no-any-return] elif hlo.binary_op == BinaryOpType.TRUEDIV: if isinstance(hlo.x1, Array) and np.isscalar(hlo.x2): # https://github.com/python/mypy/issues/16499 - return self.rec_ary(hlo.x1, ctx) / hlo.x2 # type: ignore[no-any-return] + return _verify_is_array(self.rec(hlo.x1, ctx)) / hlo.x2 # type: ignore[no-any-return] else: assert isinstance(hlo.x2, Array) and np.isscalar(hlo.x1) # https://github.com/python/mypy/issues/16499 - return hlo.x1 / self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return] + return hlo.x1 / _verify_is_array(self.rec(hlo.x2, ctx)) # type: ignore[no-any-return] else: raise NotImplementedError(hlo) else: @@ -217,7 +222,7 @@ def map_index_lambda(self, expr=expr.expr, shape=expr.shape, dtype=expr.dtype, - bindings=immutabledict({name: self.rec_ary(bnd, None) + bindings=immutabledict({name: _verify_is_array(self.rec(bnd, None)) for name, bnd in sorted(expr.bindings.items())}), var_to_reduction_descr=expr.var_to_reduction_descr, tags=expr.tags, @@ -244,12 +249,13 @@ def map_einsum(self, tags=expr.tags, axes=expr.axes, ) - return self.rec_ary(expr.args[distributive_law_descr.ioperand], ctx) + return _verify_is_array( + self.rec(expr.args[distributive_law_descr.ioperand], ctx)) else: assert isinstance(distributive_law_descr, DoNotDistribute) rec_expr = Einsum( expr.access_descriptors, - tuple(self.rec_ary(arg, None) for arg in expr.args), + tuple(_verify_is_array(self.rec(arg, None)) for arg in expr.args), expr.redn_axis_to_redn_descr, tags=expr.tags, axes=expr.axes @@ -260,7 +266,7 @@ def map_einsum(self, def map_stack(self, expr: Stack, ctx: _EinsumDistributiveLawMapperContext | None) -> Array: - rec_expr = Stack(tuple(self.rec_ary(ary, None) + rec_expr = Stack(tuple(_verify_is_array(self.rec(ary, None)) for ary in expr.arrays), expr.axis, tags=expr.tags, @@ -271,7 +277,7 @@ def map_concatenate(self, expr: Concatenate, ctx: _EinsumDistributiveLawMapperContext | None ) -> Array: - rec_expr = Concatenate(tuple(self.rec_ary(ary, None) + rec_expr = Concatenate(tuple(_verify_is_array(self.rec(ary, None)) for ary in expr.arrays), expr.axis, tags=expr.tags, @@ -282,7 +288,7 @@ def map_roll(self, expr: Roll, ctx: _EinsumDistributiveLawMapperContext | None ) -> Array: - rec_expr = Roll(self.rec_ary(expr.array, None), + rec_expr = Roll(_verify_is_array(self.rec(expr.array, None)), expr.shift, expr.axis, tags=expr.tags, @@ -293,7 +299,7 @@ def map_axis_permutation(self, expr: AxisPermutation, ctx: _EinsumDistributiveLawMapperContext | None ) -> Array: - rec_expr = AxisPermutation(self.rec_ary(expr.array, None), + rec_expr = AxisPermutation(_verify_is_array(self.rec(expr.array, None)), expr.axis_permutation, tags=expr.tags, axes=expr.axes) @@ -303,7 +309,7 @@ def _map_index_base(self, expr: IndexBase, ctx: _EinsumDistributiveLawMapperContext | None ) -> Array: - rec_expr = type(expr)(self.rec_ary(expr.array, None), + rec_expr = type(expr)(_verify_is_array(self.rec(expr.array, None)), expr.indices, tags=expr.tags, axes=expr.axes) @@ -317,7 +323,7 @@ def map_reshape(self, expr: Reshape, ctx: _EinsumDistributiveLawMapperContext | None ) -> Array: - rec_expr = Reshape(self.rec_ary(expr.array, None), + rec_expr = Reshape(_verify_is_array(self.rec(expr.array, None)), expr.newshape, expr.order, tags=expr.tags, diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index e952e4e9e..1bdcad098 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -29,7 +29,7 @@ """ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Never, TypeVar, cast from immutabledict import immutabledict @@ -637,9 +637,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: non_equality_tags=expr.non_equality_tags) -class ToIndexLambdaMapper(Mapper[Array, []], ToIndexLambdaMixin): +class ToIndexLambdaMapper(Mapper[Array, Never, []], ToIndexLambdaMixin): - def handle_unsupported_array(self, expr: Any) -> Any: + def handle_unsupported_array(self, expr: Array) -> Array: raise CannotBeLoweredToIndexLambda(type(expr)) def rec(self, expr: Array) -> Array: # type: ignore[override] diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 8cd520d94..49190c76e 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -42,7 +42,7 @@ import logging from typing import ( TYPE_CHECKING, - Any, + Never, TypeVar, cast, ) @@ -90,9 +90,9 @@ if TYPE_CHECKING: - from collections.abc import Collection, Mapping + from collections.abc import Collection, Hashable, Mapping - from pytato.function import NamedCallResult + from pytato.function import FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall @@ -101,7 +101,7 @@ # {{{ AxesTagsEquationCollector -class AxesTagsEquationCollector(Mapper[None, []]): +class AxesTagsEquationCollector(Mapper[None, Never, []]): r""" Records equations arising from operand/output axes equivalence for an array operation. This mapper implements a default set of propagation rules, @@ -595,12 +595,13 @@ class AxisTagAttacher(CopyMapper): """ def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], - tag_corresponding_redn_descr: bool): - super().__init__() + tag_corresponding_redn_descr: bool, + _function_cache: dict[Hashable, FunctionDefinition] | None = None): + super().__init__(_function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr - def rec(self, expr: ArrayOrNames) -> Any: + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: if isinstance(expr, AbstractResultWithNamedArrays | DistributedSendRefHolder): return super().rec(expr) else: diff --git a/pytato/transform/remove_broadcasts_einsum.py b/pytato/transform/remove_broadcasts_einsum.py index 8c7c224fb..2d8f7e0f0 100644 --- a/pytato/transform/remove_broadcasts_einsum.py +++ b/pytato/transform/remove_broadcasts_einsum.py @@ -31,7 +31,7 @@ from typing import cast from pytato.array import Array, Einsum, EinsumAxisDescriptor -from pytato.transform import CopyMapper, MappedT +from pytato.transform import CopyMapper, MappedT, _verify_is_array from pytato.utils import are_shape_components_equal @@ -42,7 +42,7 @@ def map_einsum(self, expr: Einsum) -> Array: descr_to_axis_len = expr._access_descr_to_axis_len() for acc_descrs, arg in zip(expr.access_descriptors, expr.args, strict=True): - arg = self.rec_ary(arg) + arg = _verify_is_array(self.rec(arg)) axes_to_squeeze: list[int] = [] for idim, acc_descr in enumerate(acc_descrs): if not are_shape_components_equal(arg.shape[idim], diff --git a/pytato/utils.py b/pytato/utils.py index aedd0dc05..10a5b7d43 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -25,6 +25,7 @@ from typing import ( TYPE_CHECKING, Any, + Never, TypeVar, cast, ) @@ -269,7 +270,7 @@ def cast_to_result_type( # {{{ dim_to_index_lambda_components -class ShapeExpressionMapper(CachedMapper[ScalarExpression, []]): +class ShapeExpressionMapper(CachedMapper[ScalarExpression, Never, []]): """ Mapper that takes a shape component and returns it as a scalar expression. """ @@ -372,7 +373,7 @@ def are_shapes_equal(shape1: ShapeType, shape2: ShapeType) -> bool: # {{{ ShapeToISLExpressionMapper -class ShapeToISLExpressionMapper(CachedMapper[isl.Aff, []]): +class ShapeToISLExpressionMapper(CachedMapper[isl.Aff, Never, []]): """ Mapper that takes a shape component and returns it as :class:`isl.Aff`. """ diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 7e685aba9..02352fb95 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -178,7 +178,7 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames, []]): +class ArrayToDotNodeInfoMapper(CachedMapper[None, None, []]): def __init__(self) -> None: super().__init__() self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {} @@ -197,8 +197,7 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: return _DotNodeInfo(title, fields, edges) # type-ignore-reason: incompatible with supertype - def handle_unsupported_array(self, # type: ignore[override] - expr: Array) -> None: + def handle_unsupported_array(self, expr: Array) -> None: # Default handler, does its best to guess how to handle fields. info = self.get_common_dot_info(expr) diff --git a/pytato/visualization/fancy_placeholder_data_flow.py b/pytato/visualization/fancy_placeholder_data_flow.py index 3d06309fd..eadae86ad 100644 --- a/pytato/visualization/fancy_placeholder_data_flow.py +++ b/pytato/visualization/fancy_placeholder_data_flow.py @@ -6,7 +6,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Never from pytools import UniqueNameGenerator @@ -100,7 +100,7 @@ def _get_dot_node_from_predecessors(node_id: str, return NoShowNode(), frozenset() -class FancyDotWriter(CachedMapper[_FancyDotWriterNode, []]): +class FancyDotWriter(CachedMapper[_FancyDotWriterNode, Never, []]): def __init__(self) -> None: super().__init__() self.vng = UniqueNameGenerator() diff --git a/test/testlib.py b/test/testlib.py index 36857197b..a28dec67e 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -2,7 +2,7 @@ import operator import random -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Never import numpy as np @@ -32,7 +32,7 @@ # {{{ tools for comparison to numpy -class NumpyBasedEvaluator(Mapper[Any, []]): +class NumpyBasedEvaluator(Mapper[Any, Never, []]): """ Mapper to return the result according to an eager evaluation array package *np*.