Skip to content

Commit

Permalink
use a class for CachedMapper caches instead of using a dict directly
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jan 24, 2025
1 parent c09e9f1 commit 9cbe7f4
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 48 deletions.
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@
["py:class", r"P\.kwargs"],
["py:class", r"lp\.LoopKernel"],
["py:class", r"_dtype_any"],
["py:class", r"(.+)\._CacheT"],
["py:class", r"(.+)\._FunctionCacheT"],
]
8 changes: 5 additions & 3 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@
CachedWalkMapper,
CopyMapper,
SubsetDependencyMapper,
TransformMapperCache,
)
from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin


if TYPE_CHECKING:
from collections.abc import Hashable, Mapping
from collections.abc import Mapping

from pytato.function import FunctionDefinition, NamedCallResult
from pytato.target import Target
Expand Down Expand Up @@ -140,9 +141,10 @@ def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
Expand Down
16 changes: 11 additions & 5 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
CachedWalkMapper,
CombineMapper,
CopyMapper,
TransformMapperCache,
_verify_is_array,
)

Expand Down Expand Up @@ -239,9 +240,11 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache:
TransformMapperCache[FunctionDefinition, []] | None = None,
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)

self.recvd_ary_to_name = recvd_ary_to_name
self.sptpo_ary_to_name = sptpo_ary_to_name
Expand All @@ -255,7 +258,10 @@ def clone_for_callee(
self, function: FunctionDefinition) -> _DistributedInputReplacer:
# Function definitions aren't allowed to contain receives,
# stored arrays promoted to part outputs, or part outputs
return type(self)({}, {}, {}, _function_cache=self._function_cache)
return type(self)(
{}, {}, {},
_function_cache=cast(
"TransformMapperCache[FunctionDefinition, []]", self._function_cache))

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
Expand Down Expand Up @@ -288,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
return new_send

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self.get_cache_key(expr)
key = self._cache.get_key(expr)
try:
return self._cache[key]
return self._cache.retrieve(expr, key=key)
except KeyError:
pass

Expand Down
207 changes: 176 additions & 31 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@

__doc__ = """
.. autoclass:: Mapper
.. autoclass:: MapperCache
.. autoclass:: CachedMapperCache
.. autoclass:: CachedMapper
.. autoclass:: TransformMapperCache
.. autoclass:: TransformMapper
.. autoclass:: TransformMapperWithExtraArgs
.. autoclass:: CopyMapper
Expand Down Expand Up @@ -150,9 +153,41 @@
A type variable representing the result type of a :class:`Mapper` when mapping
a :class:`pytato.function.FunctionDefinition`.
.. class:: CacheExprT
A type variable representing an input from which to compute a cache key in order
to cache a result.
.. class:: CacheKeyT
A type variable representing a key computed from an input expression.
.. class:: CacheResultT
A type variable representing a result to be cached.
.. class:: Scalar
See :data:`pymbolic.Scalar`.
.. class:: P
A :class:`typing.ParamSpec` used to annotate `*args` and `**kwargs`.
.. class:: _OtherResultT
Duplicate of :class:`pytato.transform.ResultT`, used for defining class-local
type aliases.
.. class:: _OtherFunctionResultT
Duplicate of :class:`pytato.transform.FunctionResultT`, used for defining
class-local type aliases.
.. class:: _OtherP
Duplicate of :class:`P`, used for defining class-local type aliases.
"""

transform_logger = logging.getLogger(__file__)
Expand All @@ -172,6 +207,12 @@ class ForeignObjectError(ValueError):
FunctionResultT = TypeVar("FunctionResultT")
P = ParamSpec("P")

# Duplicates of type variables, mainly used for defining aliases of parameterized
# types inside mapper classes
_OtherResultT = TypeVar("_OtherResultT")
_OtherFunctionResultT = TypeVar("_OtherFunctionResultT")
_OtherP = ParamSpec("_OtherP")


def _verify_is_array(expr: ArrayOrNames) -> Array:
assert isinstance(expr, Array)
Expand Down Expand Up @@ -252,6 +293,84 @@ def __call__(self,

# {{{ CachedMapper

CacheExprT = TypeVar("CacheExprT")
CacheKeyT = TypeVar("CacheKeyT")
CacheResultT = TypeVar("CacheResultT")


class MapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]):
"""
Cache for mappers.
.. automethod:: __init__
.. method:: get_key
Compute the key for an input expression.
.. automethod:: add
.. automethod:: retrieve
.. automethod:: clear
"""
def __init__(
self,
# FIXME: Figure out the right way to type annotate this
key_func: Callable[..., CacheKeyT]) -> None:
"""
Initialize the cache.
:arg key_func: Function to compute a hashable cache key from an input
expression and any extra arguments.
"""
self.get_key = key_func

self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {}

def add(
self,
key_inputs:
CacheExprT
# FIXME: Figure out the right way to type annotate these
| tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
result: CacheResultT,
key: CacheKeyT | None = None) -> CacheResultT:
"""Cache a mapping result."""
if key is None:
if isinstance(key_inputs, tuple):
expr, key_args, key_kwargs = key_inputs
key = self.get_key(expr, *key_args, **key_kwargs)
else:
key = self.get_key(key_inputs)

self._expr_key_to_result[key] = result

return result

def retrieve(
self,
key_inputs:
CacheExprT
# FIXME: Figure out the right way to type annotate these
| tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
key: CacheKeyT | None = None) -> CacheResultT:
"""Retrieve the cached mapping result."""
if key is None:
if isinstance(key_inputs, tuple):
expr, key_args, key_kwargs = key_inputs
key = self.get_key(expr, *key_args, **key_kwargs)
else:
key = self.get_key(key_inputs)

return self._expr_key_to_result[key]

def clear(self) -> None:
"""Reset the cache."""
self._expr_key_to_result = {}


class CachedMapperCache(MapperCache[CacheExprT, Hashable, CacheResultT, P]):
pass


class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
"""Mapper class that maps each node in the DAG exactly once. This loses some
information compared to :class:`Mapper` as a node is visited only from
Expand All @@ -261,18 +380,23 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
.. automethod:: get_function_definition_cache_key
.. automethod:: clone_for_callee
"""

def __init__(
self,
# Arrays are cached separately for each call stack frame, but
# functions are cached globally
_function_cache: dict[Hashable, FunctionResultT] | None = None
_cache:
CachedMapperCache[ArrayOrNames, ResultT, P] | None = None,
_function_cache:
CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None
) -> None:
super().__init__()
self._cache: dict[Hashable, ResultT] = {}

self._function_cache: dict[Hashable, FunctionResultT] = \
_function_cache if _function_cache is not None else {}
self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = (
_cache if _cache is not None
else CachedMapperCache(self.get_cache_key))

self._function_cache: CachedMapperCache[
FunctionDefinition, FunctionResultT, P] = (
_function_cache if _function_cache is not None
else CachedMapperCache(self.get_function_definition_cache_key))

def get_cache_key(
self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs
Expand All @@ -285,48 +409,59 @@ def get_function_definition_cache_key(
return (expr, *args, tuple(sorted(kwargs.items())))

def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
key = self.get_cache_key(expr, *args, **kwargs)
key = self._cache.get_key(expr, *args, **kwargs)
try:
return self._cache[key]
return self._cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
result = super().rec(expr, *args, **kwargs)
self._cache[key] = result
return result
return self._cache.add(
(expr, args, kwargs),
super().rec(expr, *args, **kwargs),
key=key)

def rec_function_definition(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> FunctionResultT:
key = self.get_function_definition_cache_key(expr, *args, **kwargs)
key = self._function_cache.get_key(expr, *args, **kwargs)
try:
return self._function_cache[key]
return self._function_cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
result = super().rec_function_definition(expr, *args, **kwargs)
self._function_cache[key] = result
return result
return self._function_cache.add(
(expr, args, kwargs),
super().rec_function_definition(expr, *args, **kwargs),
key=key)

def clone_for_callee(
self, function: FunctionDefinition) -> Self:
"""
Called to clone *self* before starting traversal of a
:class:`pytato.function.FunctionDefinition`.
"""
# Functions are cached globally, but arrays aren't
return type(self)(_function_cache=self._function_cache)

# }}}


# {{{ TransformMapper

class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]):
pass


class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
"""Base class for mappers that transform :class:`pytato.array.Array`\\ s into
other :class:`pytato.array.Array`\\ s.
Enables certain operations that can only be done if the mapping results are also
arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not
implement default mapper methods; for that, see :class:`CopyMapper`.
arrays (e.g., computing a cache key from them). Does not implement default
mapper methods; for that, see :class:`CopyMapper`.
"""
pass
def __init__(
self,
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

# }}}

Expand All @@ -343,7 +478,13 @@ class TransformMapperWithExtraArgs(
The logic in :class:`TransformMapper` purposely does not take the extra
arguments to keep the cost of its each call frame low.
"""
pass
def __init__(
self,
_cache: TransformMapperCache[ArrayOrNames, P] | None = None,
_function_cache:
TransformMapperCache[FunctionDefinition, P] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

# }}}

Expand Down Expand Up @@ -1369,22 +1510,26 @@ class CachedMapAndCopyMapper(CopyMapper):
def __init__(
self,
map_fn: Callable[[ArrayOrNames], ArrayOrNames],
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_cache: TransformMapperCache[ArrayOrNames, []] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn

def clone_for_callee(
self, function: FunctionDefinition) -> Self:
return type(self)(self.map_fn, _function_cache=self._function_cache)
return type(self)(
self.map_fn,
_function_cache=cast(
"TransformMapperCache[FunctionDefinition, []]", self._function_cache))

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
if expr in self._cache:
return self._cache[expr]

result = super().rec(self.map_fn(expr))
self._cache[expr] = result
return result
key = self._cache.get_key(expr)
try:
return self._cache.retrieve(expr, key=key)
except KeyError:
return self._cache.add(
expr, super().rec(self.map_fn(expr)), key=key)

# }}}

Expand Down
Loading

0 comments on commit 9cbe7f4

Please sign in to comment.