diff --git a/pytato/codegen.py b/pytato/codegen.py index e0731cee6..b74222ee8 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -139,9 +139,10 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, + _cache: CodeGenPreprocessor._CacheT | None = None, _function_cache: CodeGenPreprocessor._FunctionCacheT | None = None ) -> None: - super().__init__(_function_cache=_function_cache) + super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index ccb681c40..8bd78c0da 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -294,10 +294,11 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], + _cache: _DistributedInputReplacer._CacheT | None = None, _function_cache: _DistributedInputReplacer._FunctionCacheT | None = None, ) -> None: - super().__init__(_function_cache=_function_cache) + super().__init__(_cache=_cache, _function_cache=_function_cache) self.recvd_ary_to_name = recvd_ary_to_name self.sptpo_ary_to_name = sptpo_ary_to_name diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 1b38361a0..8df903c26 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -357,8 +357,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): def __init__( self, - # Arrays are cached separately for each call stack frame, but - # functions are cached globally + _cache: CachedMapper._CacheT[ResultT, P] | None = None, _function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] | None = None ) -> None: @@ -369,7 +368,9 @@ def key_func( *args: P.args, **kwargs: P.kwargs) -> Hashable: return (expr, args, tuple(sorted(kwargs.items()))) - self._cache: CachedMapper._CacheT[ResultT, P] = CachedMapperCache(key_func) + self._cache: CachedMapper._CacheT[ResultT, P] = ( + _cache if _cache is not None + else CachedMapperCache(key_func)) self._function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] = ( _function_cache if _function_cache is not None @@ -403,6 +404,7 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. """ + # Functions are cached globally, but arrays aren't return type(self)(_function_cache=self._function_cache) # }}} @@ -877,13 +879,15 @@ class CombineMapper(Mapper[ResultT, FunctionResultT, []]): def __init__( self, + _cache: CombineMapper._CacheT[ResultT] | None = None, _function_cache: CombineMapper._FunctionCacheT[FunctionResultT] | None = None ) -> None: super().__init__() - self.cache: CombineMapper._CacheT[ResultT] = CachedMapperCache( - lambda expr: expr) + self.cache: CombineMapper._CacheT[ResultT] = ( + _cache if _cache is not None + else CachedMapperCache(lambda expr: expr)) self.function_cache: CombineMapper._FunctionCacheT[FunctionResultT] = ( _function_cache if _function_cache is not None @@ -1495,9 +1499,10 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], + _cache: CachedMapAndCopyMapper._CacheT | None = None, _function_cache: CachedMapAndCopyMapper._FunctionCacheT | None = None ) -> None: - super().__init__(_function_cache=_function_cache) + super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn def clone_for_callee( diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 7c3b06b8e..3cc177445 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -596,8 +596,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], tag_corresponding_redn_descr: bool, + _cache: AxisTagAttacher._CacheT | None = None, _function_cache: AxisTagAttacher._FunctionCacheT | None = None): - super().__init__(_function_cache=_function_cache) + super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr