diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 9bf351a5c..a1c2ac62a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -28,9 +28,10 @@ from typing import TYPE_CHECKING, Any +from typing_extensions import Self + from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( Array, @@ -76,7 +77,7 @@ # {{{ NUserCollector -class NUserCollector(Mapper[None, []]): +class NUserCollector(Mapper[None, None, []]): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of times an array expression is a direct dependency of other nodes. @@ -317,7 +318,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]): +class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], None, []]): """ Mapper to get the `direct predecessors @@ -413,9 +414,14 @@ class NodeCountMapper(CachedWalkMapper[[]]): Dictionary mapping node types to number of nodes of that type. """ - def __init__(self, count_duplicates: bool = False) -> None: + def __init__( + self, + count_duplicates: bool = False, + _visited_functions: set[Any] | None = None, + ) -> None: + super().__init__(_visited_functions=_visited_functions) + from collections import defaultdict - super().__init__() self.expr_type_counts: dict[type[Any], int] = defaultdict(int) self.count_duplicates = count_duplicates @@ -423,6 +429,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: # Returns unique nodes only if count_duplicates is False return id(expr) if self.count_duplicates else expr + def get_function_definition_cache_key( + self, expr: FunctionDefinition) -> int | FunctionDefinition: + # Returns unique nodes only if count_duplicates is False + return id(expr) if self.count_duplicates else expr + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + count_duplicates=self.count_duplicates, + _visited_functions=self._visited_functions) + def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 @@ -488,15 +504,20 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]): .. autoattribute:: expr_multiplicity_counts """ - def __init__(self) -> None: + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + from collections import defaultdict - super().__init__() self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: # Returns each node, including nodes that are duplicates return id(expr) + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + # Returns each node, including nodes that are duplicates + return id(expr) + def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_multiplicity_counts[expr] += 1 @@ -530,14 +551,16 @@ class CallSiteCountMapper(CachedWalkMapper[[]]): The number of nodes. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.count = 0 def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - @memoize_method + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return diff --git a/pytato/codegen.py b/pytato/codegen.py index d08445517..85ac4052d 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -68,9 +68,9 @@ if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Hashable, Mapping - from pytato.function import NamedCallResult + from pytato.function import FunctionDefinition, NamedCallResult from pytato.target import Target @@ -136,10 +136,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] ====================================== ===================================== """ - def __init__(self, target: Target, - kernels_seen: dict[str, lp.LoopKernel] | None = None - ) -> None: - super().__init__() + def __init__( + self, + target: Target, + kernels_seen: dict[str, lp.LoopKernel] | None = None, + _function_cache: dict[Hashable, FunctionDefinition] | None = None + ) -> None: + super().__init__(_function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target @@ -266,13 +269,16 @@ def normalize_outputs( @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NamesValidityChecker(CachedWalkMapper[[]]): - def __init__(self) -> None: + def __init__(self, _visited_functions: set[Any] | None = None) -> None: self.name_to_input: dict[str, InputArgumentBase] = {} - super().__init__() + super().__init__(_visited_functions=_visited_functions) def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def post_visit(self, expr: Any) -> None: if isinstance(expr, Placeholder | SizeParam | DataWrapper): if expr.name is not None: diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 4a0eb5897..8e5940a06 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -288,8 +288,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], + _function_cache: dict[Hashable, FunctionDefinition] | None = None, ) -> None: - super().__init__() + super().__init__(_function_cache=_function_cache) self.recvd_ary_to_name = recvd_ary_to_name self.sptpo_ary_to_name = sptpo_ary_to_name @@ -303,7 +304,7 @@ def clone_for_callee( self, function: FunctionDefinition) -> _DistributedInputReplacer: # Function definitions aren't allowed to contain receives, # stored arrays promoted to part outputs, or part outputs - return type(self)({}, {}, {}) + return type(self)({}, {}, {}, _function_cache=self._function_cache) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -456,7 +457,7 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[frozenset[CommunicationOpIdentifier]]): + CombineMapper[frozenset[CommunicationOpIdentifier], None]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index 5e8aa526d..bde9e2277 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -144,8 +144,8 @@ class MissingRecvError(DistributedPartitionVerificationError): @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class _SeenNodesWalkMapper(CachedWalkMapper[[]]): - def __init__(self) -> None: - super().__init__() + def __init__(self, _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.seen_nodes: set[ArrayOrNames] = set() def get_cache_key(self, expr: ArrayOrNames) -> int: diff --git a/pytato/equality.py b/pytato/equality.py index 1ef2a88ed..6eed44b48 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -27,8 +27,6 @@ from typing import TYPE_CHECKING, Any -from pytools import memoize_method - from pytato.array import ( AbstractResultWithNamedArrays, AdvancedIndexInContiguousAxes, @@ -85,26 +83,31 @@ class EqualityComparer: more on this. """ def __init__(self) -> None: + # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], bool] = {} - def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool: + def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool: cache_key = id(expr1), id(expr2) try: return self._cache[cache_key] except KeyError: - - method: Callable[[Array | AbstractResultWithNamedArrays, Any], - bool] - - try: - method = getattr(self, expr1._mapper_method) - except AttributeError: - if isinstance(expr1, Array): - result = self.handle_unsupported_array(expr1, expr2) + if expr1 is expr2: + result = True + elif isinstance(expr1, ArrayOrNames): + method: Callable[[ArrayOrNames, Any], bool] + try: + method = getattr(self, expr1._mapper_method) + except AttributeError: + if isinstance(expr1, Array): + result = self.handle_unsupported_array(expr1, expr2) + else: + result = self.map_foreign(expr1, expr2) else: - result = self.map_foreign(expr1, expr2) + result = method(expr1, expr2) + elif isinstance(expr1, FunctionDefinition): + result = self.map_function_definition(expr1, expr2) else: - result = (expr1 is expr2) or method(expr1, expr2) + result = self.map_foreign(expr1, expr2) self._cache[cache_key] = result return result @@ -296,7 +299,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.tags == expr2.tags ) - @memoize_method def map_function_definition(self, expr1: FunctionDefinition, expr2: Any ) -> bool: return (expr1.__class__ is expr2.__class__ @@ -310,7 +312,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any def map_call(self, expr1: Call, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ - and self.map_function_definition(expr1.function, expr2.function) + and self.rec(expr1.function, expr2.function) and frozenset(expr1.bindings) == frozenset(expr2.bindings) and all(self.rec(bnd, expr2.bindings[name]) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index f2e7f066b..952b8f8e2 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -31,8 +31,6 @@ import numpy as np from immutabledict import immutabledict -from pytools import memoize_method - from pytato.array import ( Array, Axis, @@ -58,7 +56,7 @@ # {{{ Reprifier -class Reprifier(Mapper[str, [int]]): +class Reprifier(Mapper[str, str, [int]]): """ Stringifies :mod:`pytato`-types to closely resemble CPython's implementation of :func:`repr` for its builtin datatypes. @@ -71,6 +69,7 @@ def __init__(self, self.truncation_depth = truncation_depth self.truncation_string = truncation_string + # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], str] = {} def rec(self, expr: Any, depth: int) -> str: @@ -82,6 +81,15 @@ def rec(self, expr: Any, depth: int) -> str: self._cache[cache_key] = result return result + def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str: + cache_key = (id(expr), depth) + try: + return self._cache[cache_key] + except KeyError: + result = super().rec_function_definition(expr, depth) + self._cache[cache_key] = result + return result + def __call__(self, expr: Any, depth: int = 0) -> str: return self.rec(expr, depth) @@ -171,7 +179,6 @@ def _get_field_val(field: str) -> str: for field in dataclasses.fields(type(expr))) + ")") - @memoize_method def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string @@ -194,7 +201,7 @@ def map_call(self, expr: Call, depth: int) -> str: def _get_field_val(field: str) -> str: if field == "function": - return self.map_function_definition(expr.function, depth+1) + return self.rec_function_definition(expr.function, depth+1) else: return self.rec(getattr(expr, field), depth+1) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 528de5d0e..84249c3b8 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -384,7 +384,7 @@ def update_t_unit(self, t_unit: lp.TranslationUnit) -> None: # {{{ codegen mapper -class CodeGenMapper(Mapper[ImplementedResult, [CodeGenState]]): +class CodeGenMapper(Mapper[ImplementedResult, None, [CodeGenState]]): """A mapper for generating code for nodes in the computation graph. """ exprgen_mapper: InlinedExpressionGenMapper diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 0c9bd413c..2561f57f7 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -62,6 +62,7 @@ SizeParam, Stack, ) +from pytato.function import FunctionDefinition from pytato.raising import BinaryOpType, C99CallOp from pytato.reductions import ( AllReductionOperation, @@ -171,7 +172,7 @@ def _is_slice_trivial(slice_: NormalizedSlice, } -class NumpyCodegenMapper(CachedMapper[str, []]): +class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition, []]): """ .. note:: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 393583037..b4be9aceb 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -43,7 +43,6 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -86,9 +85,6 @@ ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) -TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper - Array, AbstractResultWithNamedArrays, ArrayOrNames) -CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = frozenset[Array] @@ -142,13 +138,15 @@ A type variable representing the input type of a :class:`Mapper`. -.. class:: CombineT +.. class:: ResultT - A type variable representing the type of a :class:`CombineMapper`. + A type variable representing the result type of a :class:`Mapper` when mapping + a :class:`pytato.Array` or :class:`pytato.AbstractResultWithNamedArrays`. -.. class:: ResultT +.. class:: FunctionResultT - A type variable representing the result type of a :class:`Mapper`. + A type variable representing the result type of a :class:`Mapper` when mapping + a :class:`pytato.function.FunctionDefinition`. .. class:: Scalar @@ -165,10 +163,11 @@ class UnsupportedArrayError(ValueError): # {{{ mapper base class ResultT = TypeVar("ResultT") +FunctionResultT = TypeVar("FunctionResultT") P = ParamSpec("P") -class Mapper(Generic[ResultT, P]): +class Mapper(Generic[ResultT, FunctionResultT, P]): """A class that when called with a :class:`pytato.Array` recursively iterates over the DAG, calling the *_mapper_method* of each node. Users of this class are expected to override the methods of this class or create a @@ -194,7 +193,7 @@ def handle_unsupported_array(self, expr: MappedT, raise UnsupportedArrayError( f"{type(self).__name__} cannot handle expressions of type {type(expr)}") - def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> ResultT: + def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> Any: """Mapper method that is invoked for an object of class for which a mapper method does not exist in this mapper. """ @@ -203,7 +202,7 @@ def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> ResultT: def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Call the mapper method of *expr* and return the result.""" - method: Callable[..., Array] | None + method: Callable[..., Any] | None try: method = getattr(self, expr._mapper_method) @@ -218,11 +217,25 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: else: return self.handle_unsupported_array(expr, *args, **kwargs) else: - return self.map_foreign(expr, *args, **kwargs) + return cast("ResultT", self.map_foreign(expr, *args, **kwargs)) assert method is not None return cast("ResultT", method(expr, *args, **kwargs)) + def rec_function_definition( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> FunctionResultT: + """Call the mapper method of *expr* and return the result.""" + method: Callable[..., Any] | None + + try: + method = self.map_function_definition # type: ignore[attr-defined] + except AttributeError: + return cast("FunctionResultT", self.map_foreign(expr, *args, **kwargs)) + + assert method is not None + return cast("FunctionResultT", method(expr, *args, **kwargs)) + def __call__(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Handle the mapping of *expr*.""" @@ -233,23 +246,42 @@ def __call__(self, # {{{ CachedMapper -class CachedMapper(Mapper[ResultT, P]): +class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """Mapper class that maps each node in the DAG exactly once. This loses some information compared to :class:`Mapper` as a node is visited only from one of its predecessors. .. automethod:: get_cache_key + .. automethod:: get_function_definition_cache_key + .. automethod:: clone_for_callee """ - def __init__(self) -> None: + def __init__( + self, + # Arrays are cached separately for each call stack frame, but + # functions are cached globally + _function_cache: dict[Hashable, FunctionResultT] | None = None + ) -> None: super().__init__() self._cache: dict[Hashable, ResultT] = {} + if _function_cache is not None: + function_cache = _function_cache + else: + function_cache = {} + + self._function_cache: dict[Hashable, FunctionResultT] = function_cache + def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs ) -> Hashable: return (expr, *args, tuple(sorted(kwargs.items()))) + def get_function_definition_cache_key( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> Hashable: + return (expr, *args, tuple(sorted(kwargs.items()))) + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: key = self.get_cache_key(expr, *args, **kwargs) try: @@ -259,12 +291,31 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: self._cache[key] = result return result + def rec_function_definition( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> FunctionResultT: + key = self.get_function_definition_cache_key(expr, *args, **kwargs) + try: + return self._function_cache[key] + except KeyError: + result = super().rec_function_definition(expr, *args, **kwargs) + self._function_cache[key] = result + return result + + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + return type(self)(_function_cache=self._function_cache) + # }}} # {{{ TransformMapper -class TransformMapper(CachedMapper[ArrayOrNames, []]): +class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """Base class for mappers that transform :class:`pytato.array.Array`\\ s into other :class:`pytato.array.Array`\\ s. @@ -272,28 +323,20 @@ class TransformMapper(CachedMapper[ArrayOrNames, []]): arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. - .. automethod:: clone_for_callee """ def rec_ary(self, expr: Array) -> Array: res = self.rec(expr) assert isinstance(res, Array) return res - def clone_for_callee(self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - return type(self)() - # }}} # {{{ TransformMapperWithExtraArgs class TransformMapperWithExtraArgs( - CachedMapper[ArrayOrNames, P], - Mapper[ArrayOrNames, P] + CachedMapper[ArrayOrNames, FunctionDefinition, P], + Mapper[ArrayOrNames, FunctionDefinition, P] ): """ Similar to :class:`TransformMapper`, but each mapper method takes extra @@ -301,21 +344,12 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. - - .. automethod:: clone_for_callee """ def rec_ary(self, expr: Array, *args: P.args, **kwargs: P.kwargs) -> Array: res = self.rec(expr, *args, **kwargs) assert isinstance(res, Array) return res - def clone_for_callee(self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - return type(self)() - # }}} @@ -492,7 +526,6 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array: dtype=expr.dtype, tags=expr.tags, axes=expr.axes, non_equality_tags=expr.non_equality_tags) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: # spawn a new mapper to avoid unsound cache hits, since the namespace of the @@ -503,7 +536,7 @@ def map_function_definition(self, return dataclasses.replace(expr, returns=immutabledict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return Call(self.map_function_definition(expr.function), + return Call(self.rec_function_definition(expr.function), immutabledict({name: self.rec(bnd) for name, bnd in expr.bindings.items()}), tags=expr.tags, @@ -717,7 +750,7 @@ def map_function_definition( def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> AbstractResultWithNamedArrays: - return Call(self.map_function_definition(expr.function, *args, **kwargs), + return Call(self.rec_function_definition(expr.function, *args, **kwargs), immutabledict({name: self.rec(bnd, *args, **kwargs) for name, bnd in expr.bindings.items()}), tags=expr.tags, @@ -734,7 +767,7 @@ def map_named_call_result(self, expr: NamedCallResult, # {{{ CombineMapper -class CombineMapper(Mapper[ResultT, []]): +class CombineMapper(Mapper[ResultT, FunctionResultT, []]): """ Abstract mapper that recursively combines the results of user nodes of a given expression. @@ -744,6 +777,9 @@ class CombineMapper(Mapper[ResultT, []]): def __init__(self) -> None: super().__init__() self.cache: dict[ArrayOrNames, ResultT] = {} + # Don't need to pass function cache as argument here, because unlike + # CachedMapper we're not creating a new mapper for each call + self.function_cache: dict[FunctionDefinition, FunctionResultT] = {} def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[ResultT, ...]: @@ -756,6 +792,14 @@ def rec(self, expr: ArrayOrNames) -> ResultT: self.cache[expr] = result return result + def rec_function_definition( + self, expr: FunctionDefinition) -> FunctionResultT: + if expr in self.function_cache: + return self.function_cache[expr] + result: FunctionResultT = super().rec_function_definition(expr) + self.function_cache[expr] = result + return result + def __call__(self, expr: ArrayOrNames) -> ResultT: return self.rec(expr) @@ -839,8 +883,7 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> ResultT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) - @memoize_method - def map_function_definition(self, expr: FunctionDefinition) -> ResultT: + def map_function_definition(self, expr: FunctionDefinition) -> FunctionResultT: raise NotImplementedError("Combining results from a callee expression" " is context-dependent. Derived classes" " must override map_function_definition.") @@ -858,7 +901,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper -class DependencyMapper(CombineMapper[R]): +class DependencyMapper(CombineMapper[R, R]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of :class:`pytato.array.Array`'s it depends on. @@ -920,14 +963,13 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. return frozenset() def map_call(self, expr: Call) -> R: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[self.rec(bnd) for bnd in expr.bindings.values()]) def map_named_call_result(self, expr: NamedCallResult) -> R: @@ -958,7 +1000,8 @@ def combine(self, *args: frozenset[Array]) -> frozenset[Array]: # {{{ InputGatherer -class InputGatherer(CombineMapper[frozenset[InputArgumentBase]]): +class InputGatherer( + CombineMapper[frozenset[InputArgumentBase], frozenset[InputArgumentBase]]): """ Mapper to combine all instances of :class:`pytato.array.InputArgumentBase` that an array expression depends on. @@ -977,7 +1020,6 @@ def map_data_wrapper(self, expr: DataWrapper) -> frozenset[InputArgumentBase]: def map_size_param(self, expr: SizeParam) -> frozenset[SizeParam]: return frozenset([expr]) - @memoize_method def map_function_definition(self, expr: FunctionDefinition ) -> frozenset[InputArgumentBase]: # get rid of placeholders local to the function. @@ -999,7 +1041,7 @@ def map_function_definition(self, expr: FunctionDefinition return frozenset(result) def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[ self.rec(bnd) for name, bnd in sorted(expr.bindings.items())]) @@ -1009,7 +1051,8 @@ def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: # {{{ SizeParamGatherer -class SizeParamGatherer(CombineMapper[frozenset[SizeParam]]): +class SizeParamGatherer( + CombineMapper[frozenset[SizeParam], frozenset[SizeParam]]): """ Mapper to combine all instances of :class:`pytato.array.SizeParam` that an array expression depends on. @@ -1022,14 +1065,13 @@ def combine(self, *args: frozenset[SizeParam] def map_size_param(self, expr: SizeParam) -> frozenset[SizeParam]: return frozenset([expr]) - @memoize_method def map_function_definition(self, expr: FunctionDefinition ) -> frozenset[SizeParam]: return self.combine(*[self.rec(ret) for ret in expr.returns.values()]) def map_call(self, expr: Call) -> frozenset[SizeParam]: - return self.combine(self.map_function_definition(expr.function), + return self.combine(self.rec_function_definition(expr.function), *[ self.rec(bnd) for name, bnd in sorted(expr.bindings.items())]) @@ -1039,7 +1081,7 @@ def map_call(self, expr: Call) -> frozenset[SizeParam]: # {{{ WalkMapper -class WalkMapper(Mapper[None, P]): +class WalkMapper(Mapper[None, None, P]): """ A mapper that walks over all the arrays in a :class:`pytato.Array`. @@ -1228,7 +1270,7 @@ def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return - self.map_function_definition(expr.function, *args, **kwargs) + self.rec_function_definition(expr.function, *args, **kwargs) for bnd in expr.bindings.values(): self.rec(bnd, *args, **kwargs) @@ -1255,24 +1297,46 @@ class CachedWalkMapper(WalkMapper[P]): one of its predecessors. """ - def __init__(self) -> None: + def __init__( + self, + _visited_functions: set[Any] | None = None) -> None: super().__init__() - self._visited_nodes: set[Any] = set() + self._visited_arrays_or_names: set[Any] = set() + + if _visited_functions is not None: + visited_functions = _visited_functions + else: + visited_functions = set() + + self._visited_functions: set[Any] = visited_functions def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError + def get_function_definition_cache_key( + self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any ) -> None: cache_key = self.get_cache_key(expr, *args, **kwargs) - if cache_key in self._visited_nodes: + if cache_key in self._visited_arrays_or_names: return super().rec(expr, *args, **kwargs) - self._visited_nodes.add(cache_key) + self._visited_arrays_or_names.add(cache_key) + + def rec_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + cache_key = self.get_function_definition_cache_key(expr, *args, **kwargs) + if cache_key in self._visited_functions: + return + + super().rec_function_definition(expr, *args, **kwargs) + self._visited_functions.add(cache_key) def clone_for_callee(self, function: FunctionDefinition) -> Self: - return type(self)() + return type(self)(_visited_functions=self._visited_functions) # }}} @@ -1291,8 +1355,10 @@ class TopoSortMapper(CachedWalkMapper[[]]): :class:`~pytato.function.FunctionDefinition`. """ - def __init__(self) -> None: - super().__init__() + def __init__( + self, + _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) self.topological_order: list[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: @@ -1301,7 +1367,6 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. return @@ -1317,13 +1382,17 @@ class CachedMapAndCopyMapper(CopyMapper): traversals are memoized i.e. each node is mapped via *map_fn* exactly once. """ - def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: - super().__init__() + def __init__( + self, + map_fn: Callable[[ArrayOrNames], ArrayOrNames], + _function_cache: dict[Hashable, FunctionDefinition] | None = None + ) -> None: + super().__init__(_function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn def clone_for_callee( self, function: FunctionDefinition) -> Self: - return type(self)(self.map_fn) + return type(self)(self.map_fn, _function_cache=self._function_cache) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: if expr in self._cache: @@ -1371,7 +1440,7 @@ def _materialize_if_mpms(expr: Array, return MPMSMaterializerAccumulator(materialized_predecessors, expr) -class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, []]): +class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, None, []]): """ See :func:`materialize_with_mpms` for an explanation. @@ -1648,7 +1717,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: # {{{ UsersCollector -class UsersCollector(CachedMapper[None, []]): +class UsersCollector(CachedMapper[None, None, []]): """ Maps a graph to a dictionary representation mapping a node to its users, i.e. all the nodes using its value. @@ -1772,7 +1841,6 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> None: self.rec_idx_or_size_tuple(expr, expr.shape) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: raise AssertionError("Control shouldn't reach at this point." " Instantiate another UsersCollector to" diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index bc3d69909..ea08a6ac6 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -58,6 +58,7 @@ class PlaceholderSubstitutor(CopyMapper): """ def __init__(self, substitutions: Mapping[str, Array]) -> None: + # Ignoring function cache, not needed super().__init__() self.substitutions = substitutions diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index e952e4e9e..ec00e9f45 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -637,7 +637,7 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: non_equality_tags=expr.non_equality_tags) -class ToIndexLambdaMapper(Mapper[Array, []], ToIndexLambdaMixin): +class ToIndexLambdaMapper(Mapper[Array, None, []], ToIndexLambdaMixin): def handle_unsupported_array(self, expr: Any) -> Any: raise CannotBeLoweredToIndexLambda(type(expr)) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 8cd520d94..4154c5962 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -90,9 +90,9 @@ if TYPE_CHECKING: - from collections.abc import Collection, Mapping + from collections.abc import Collection, Hashable, Mapping - from pytato.function import NamedCallResult + from pytato.function import FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall @@ -101,7 +101,7 @@ # {{{ AxesTagsEquationCollector -class AxesTagsEquationCollector(Mapper[None, []]): +class AxesTagsEquationCollector(Mapper[None, None, []]): r""" Records equations arising from operand/output axes equivalence for an array operation. This mapper implements a default set of propagation rules, @@ -595,8 +595,9 @@ class AxisTagAttacher(CopyMapper): """ def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], - tag_corresponding_redn_descr: bool): - super().__init__() + tag_corresponding_redn_descr: bool, + _function_cache: dict[Hashable, FunctionDefinition] | None = None): + super().__init__(_function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr diff --git a/pytato/utils.py b/pytato/utils.py index aedd0dc05..78e5463b3 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -269,7 +269,7 @@ def cast_to_result_type( # {{{ dim_to_index_lambda_components -class ShapeExpressionMapper(CachedMapper[ScalarExpression, []]): +class ShapeExpressionMapper(CachedMapper[ScalarExpression, None, []]): """ Mapper that takes a shape component and returns it as a scalar expression. """ @@ -372,7 +372,7 @@ def are_shapes_equal(shape1: ShapeType, shape2: ShapeType) -> bool: # {{{ ShapeToISLExpressionMapper -class ShapeToISLExpressionMapper(CachedMapper[isl.Aff, []]): +class ShapeToISLExpressionMapper(CachedMapper[isl.Aff, None, []]): """ Mapper that takes a shape component and returns it as :class:`isl.Aff`. """ diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 7e685aba9..3d4512079 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -178,7 +178,7 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames, []]): +class ArrayToDotNodeInfoMapper(CachedMapper[None, None, []]): def __init__(self) -> None: super().__init__() self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {} diff --git a/pytato/visualization/fancy_placeholder_data_flow.py b/pytato/visualization/fancy_placeholder_data_flow.py index 3d06309fd..ace1c32b4 100644 --- a/pytato/visualization/fancy_placeholder_data_flow.py +++ b/pytato/visualization/fancy_placeholder_data_flow.py @@ -23,6 +23,7 @@ Placeholder, Stack, ) +from pytato.function import FunctionDefinition from pytato.transform import CachedMapper @@ -100,7 +101,7 @@ def _get_dot_node_from_predecessors(node_id: str, return NoShowNode(), frozenset() -class FancyDotWriter(CachedMapper[_FancyDotWriterNode, []]): +class FancyDotWriter(CachedMapper[_FancyDotWriterNode, FunctionDefinition, []]): def __init__(self) -> None: super().__init__() self.vng = UniqueNameGenerator() diff --git a/test/testlib.py b/test/testlib.py index 36857197b..59d4f9f50 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -32,7 +32,7 @@ # {{{ tools for comparison to numpy -class NumpyBasedEvaluator(Mapper[Any, []]): +class NumpyBasedEvaluator(Mapper[Any, None, []]): """ Mapper to return the result according to an eager evaluation array package *np*.