Skip to content

Commit

Permalink
add cache argument to mapper constructors so derived classes can crea…
Browse files Browse the repository at this point in the history
…te their own cache classes
  • Loading branch information
majosm committed Jan 16, 2025
1 parent 19960f2 commit c25dce2
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
3 changes: 2 additions & 1 deletion pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_cache: CodeGenPreprocessor._CacheT | None = None,
_function_cache: CodeGenPreprocessor._FunctionCacheT | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
Expand Down
3 changes: 2 additions & 1 deletion pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,11 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_cache: _DistributedInputReplacer._CacheT | None = None,
_function_cache:
_DistributedInputReplacer._FunctionCacheT | None = None,
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)

self.recvd_ary_to_name = recvd_ary_to_name
self.sptpo_ary_to_name = sptpo_ary_to_name
Expand Down
17 changes: 11 additions & 6 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]):

def __init__(
self,
# Arrays are cached separately for each call stack frame, but
# functions are cached globally
_cache: CachedMapper._CacheT[ResultT, P] | None = None,
_function_cache:
CachedMapper._FunctionCacheT[FunctionResultT, P] | None = None
) -> None:
Expand All @@ -369,7 +368,9 @@ def key_func(
*args: P.args, **kwargs: P.kwargs) -> Hashable:
return (expr, args, tuple(sorted(kwargs.items())))

self._cache: CachedMapper._CacheT[ResultT, P] = CachedMapperCache(key_func)
self._cache: CachedMapper._CacheT[ResultT, P] = (
_cache if _cache is not None
else CachedMapperCache(key_func))

self._function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] = (
_function_cache if _function_cache is not None
Expand Down Expand Up @@ -403,6 +404,7 @@ def clone_for_callee(
Called to clone *self* before starting traversal of a
:class:`pytato.function.FunctionDefinition`.
"""
# Functions are cached globally, but arrays aren't
return type(self)(_function_cache=self._function_cache)

# }}}
Expand Down Expand Up @@ -877,13 +879,15 @@ class CombineMapper(Mapper[ResultT, FunctionResultT, []]):

def __init__(
self,
_cache: CombineMapper._CacheT[ResultT] | None = None,
_function_cache:
CombineMapper._FunctionCacheT[FunctionResultT] | None = None
) -> None:
super().__init__()

self.cache: CombineMapper._CacheT[ResultT] = CachedMapperCache(
lambda expr: expr)
self.cache: CombineMapper._CacheT[ResultT] = (
_cache if _cache is not None
else CachedMapperCache(lambda expr: expr))

self.function_cache: CombineMapper._FunctionCacheT[FunctionResultT] = (
_function_cache if _function_cache is not None
Expand Down Expand Up @@ -1495,9 +1499,10 @@ class CachedMapAndCopyMapper(CopyMapper):
def __init__(
self,
map_fn: Callable[[ArrayOrNames], ArrayOrNames],
_cache: CachedMapAndCopyMapper._CacheT | None = None,
_function_cache: CachedMapAndCopyMapper._FunctionCacheT | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn

def clone_for_callee(
Expand Down
3 changes: 2 additions & 1 deletion pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,9 @@ class AxisTagAttacher(CopyMapper):
def __init__(self,
axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]],
tag_corresponding_redn_descr: bool,
_cache: AxisTagAttacher._CacheT | None = None,
_function_cache: AxisTagAttacher._FunctionCacheT | None = None):
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags
self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr

Expand Down

0 comments on commit c25dce2

Please sign in to comment.