Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache mapped functions #531

Merged
merged 13 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Never

from typing_extensions import Self

from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

from pytato.array import (
Array,
Expand Down Expand Up @@ -76,7 +77,7 @@

# {{{ NUserCollector

class NUserCollector(Mapper[None, []]):
class NUserCollector(Mapper[None, None, []]):
"""
A :class:`pytato.transform.CachedWalkMapper` that records the number of
times an array expression is a direct dependency of other nodes.
Expand Down Expand Up @@ -317,7 +318,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:

# {{{ DirectPredecessorsGetter

class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]):
class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
"""
Mapper to get the
`direct predecessors
Expand Down Expand Up @@ -413,16 +414,31 @@ class NodeCountMapper(CachedWalkMapper[[]]):
Dictionary mapping node types to number of nodes of that type.
"""

def __init__(self, count_duplicates: bool = False) -> None:
def __init__(
self,
count_duplicates: bool = False,
_visited_functions: set[Any] | None = None,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can something more precise than Any be said here? (Also for other _visited_functions annotations.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like I did that to match _visited_arrays_or_names in CachedWalkMapper, which also uses Any. I guess I could change both to Hashable?

) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
super().__init__()
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
self.count_duplicates = count_duplicates

def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
# Returns unique nodes only if count_duplicates is False
return id(expr) if self.count_duplicates else expr

def get_function_definition_cache_key(
self, expr: FunctionDefinition) -> int | FunctionDefinition:
# Returns unique nodes only if count_duplicates is False
return id(expr) if self.count_duplicates else expr

def clone_for_callee(self, function: FunctionDefinition) -> Self:
return type(self)(
count_duplicates=self.count_duplicates,
_visited_functions=self._visited_functions)

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_type_counts[type(expr)] += 1
Expand Down Expand Up @@ -488,15 +504,20 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):

.. autoattribute:: expr_multiplicity_counts
"""
def __init__(self) -> None:
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
super().__init__()
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_multiplicity_counts[expr] += 1
Expand Down Expand Up @@ -530,14 +551,16 @@ class CallSiteCountMapper(CachedWalkMapper[[]]):
The number of nodes.
"""

def __init__(self) -> None:
super().__init__()
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)
self.count = 0

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

@memoize_method
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def map_function_definition(self, expr: FunctionDefinition) -> None:
if not self.visit(expr):
return
Expand Down
22 changes: 14 additions & 8 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@


if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Hashable, Mapping

from pytato.function import NamedCallResult
from pytato.function import FunctionDefinition, NamedCallResult
from pytato.target import Target


Expand Down Expand Up @@ -136,10 +136,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
====================================== =====================================
"""

def __init__(self, target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None
) -> None:
super().__init__()
def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
Expand Down Expand Up @@ -266,13 +269,16 @@ def normalize_outputs(

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NamesValidityChecker(CachedWalkMapper[[]]):
def __init__(self) -> None:
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
self.name_to_input: dict[str, InputArgumentBase] = {}
super().__init__()
super().__init__(_visited_functions=_visited_functions)

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if isinstance(expr, Placeholder | SizeParam | DataWrapper):
if expr.name is not None:
Expand Down
18 changes: 13 additions & 5 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
TYPE_CHECKING,
Any,
Generic,
Never,
TypeVar,
cast,
)
Expand All @@ -89,7 +90,13 @@
DistributedSendRefHolder,
)
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, CopyMapper
from pytato.transform import (
ArrayOrNames,
CachedWalkMapper,
CombineMapper,
CopyMapper,
_verify_is_array,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -288,8 +295,9 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
) -> None:
super().__init__()
super().__init__(_function_cache=_function_cache)

self.recvd_ary_to_name = recvd_ary_to_name
self.sptpo_ary_to_name = sptpo_ary_to_name
Expand All @@ -303,7 +311,7 @@ def clone_for_callee(
self, function: FunctionDefinition) -> _DistributedInputReplacer:
# Function definitions aren't allowed to contain receives,
# stored arrays promoted to part outputs, or part outputs
return type(self)({}, {}, {})
return type(self)({}, {}, {}, _function_cache=self._function_cache)

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
Expand Down Expand Up @@ -394,7 +402,7 @@ def _make_distributed_partition(

for name, val in name_to_part_output.items():
assert name not in name_to_output
name_to_output[name] = comm_replacer.rec_ary(val)
name_to_output[name] = _verify_is_array(comm_replacer.rec(val))

comm_ids = part_comm_ids[part_id]

Expand Down Expand Up @@ -456,7 +464,7 @@ def _recv_to_comm_id(


class _LocalSendRecvDepGatherer(
CombineMapper[frozenset[CommunicationOpIdentifier]]):
CombineMapper[frozenset[CommunicationOpIdentifier], Never]):
def __init__(self, local_rank: int) -> None:
super().__init__()
self.local_comm_ids_to_needed_comm_ids: \
Expand Down
4 changes: 2 additions & 2 deletions pytato/distributed/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ class MissingRecvError(DistributedPartitionVerificationError):

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class _SeenNodesWalkMapper(CachedWalkMapper[[]]):
def __init__(self) -> None:
super().__init__()
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)
self.seen_nodes: set[ArrayOrNames] = set()

def get_cache_key(self, expr: ArrayOrNames) -> int:
Expand Down
37 changes: 20 additions & 17 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@

from typing import TYPE_CHECKING, Any

from pytools import memoize_method

from pytato.array import (
AbstractResultWithNamedArrays,
AdvancedIndexInContiguousAxes,
Expand All @@ -49,13 +47,14 @@
SizeParam,
Stack,
)
from pytato.function import FunctionDefinition


if TYPE_CHECKING:
from collections.abc import Callable

from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.function import Call, NamedCallResult
from pytato.loopy import LoopyCall, LoopyCallResult

__doc__ = """
Expand Down Expand Up @@ -85,26 +84,31 @@ class EqualityComparer:
more on this.
"""
def __init__(self) -> None:
# Uses the same cache for both arrays and functions
self._cache: dict[tuple[int, int], bool] = {}

def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool:
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
cache_key = id(expr1), id(expr2)
try:
return self._cache[cache_key]
except KeyError:

method: Callable[[Array | AbstractResultWithNamedArrays, Any],
bool]

try:
method = getattr(self, expr1._mapper_method)
except AttributeError:
if isinstance(expr1, Array):
result = self.handle_unsupported_array(expr1, expr2)
if expr1 is expr2:
result = True
elif isinstance(expr1, ArrayOrNames):
method: Callable[[ArrayOrNames, Any], bool]
try:
method = getattr(self, expr1._mapper_method)
except AttributeError:
if isinstance(expr1, Array):
result = self.handle_unsupported_array(expr1, expr2)
else:
result = self.map_foreign(expr1, expr2)
else:
result = self.map_foreign(expr1, expr2)
result = method(expr1, expr2)
elif isinstance(expr1, FunctionDefinition):
result = self.map_function_definition(expr1, expr2)
else:
result = (expr1 is expr2) or method(expr1, expr2)
result = self.map_foreign(expr1, expr2)

self._cache[cache_key] = result
return result
Expand Down Expand Up @@ -296,7 +300,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
and expr1.tags == expr2.tags
)

@memoize_method
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
) -> bool:
return (expr1.__class__ is expr2.__class__
Expand All @@ -310,7 +313,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any

def map_call(self, expr1: Call, expr2: Any) -> bool:
return (expr1.__class__ is expr2.__class__
and self.map_function_definition(expr1.function, expr2.function)
and self.rec(expr1.function, expr2.function)
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
and all(self.rec(bnd,
expr2.bindings[name])
Expand Down
24 changes: 17 additions & 7 deletions pytato/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import numpy as np
from immutabledict import immutabledict

from pytools import memoize_method

from pytato.array import (
Array,
Axis,
Expand All @@ -41,7 +39,7 @@
IndexLambda,
ReductionDescriptor,
)
from pytato.transform import Mapper
from pytato.transform import ForeignObjectError, Mapper


if TYPE_CHECKING:
Expand All @@ -58,7 +56,7 @@

# {{{ Reprifier

class Reprifier(Mapper[str, [int]]):
class Reprifier(Mapper[str, str, [int]]):
"""
Stringifies :mod:`pytato`-types to closely resemble CPython's implementation
of :func:`repr` for its builtin datatypes.
Expand All @@ -71,14 +69,27 @@ def __init__(self,
self.truncation_depth = truncation_depth
self.truncation_string = truncation_string

# Uses the same cache for both arrays and functions
self._cache: dict[tuple[int, int], str] = {}

def rec(self, expr: Any, depth: int) -> str:
cache_key = (id(expr), depth)
try:
return self._cache[cache_key]
except KeyError:
result = super().rec(expr, depth)
try:
result = super().rec(expr, depth)
except ForeignObjectError:
result = self.map_foreign(expr, depth)
self._cache[cache_key] = result
return result

def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
cache_key = (id(expr), depth)
try:
return self._cache[cache_key]
except KeyError:
result = super().rec_function_definition(expr, depth)
self._cache[cache_key] = result
return result

Expand Down Expand Up @@ -171,7 +182,6 @@ def _get_field_val(field: str) -> str:
for field in dataclasses.fields(type(expr)))
+ ")")

@memoize_method
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string
Expand All @@ -194,7 +204,7 @@ def map_call(self, expr: Call, depth: int) -> str:

def _get_field_val(field: str) -> str:
if field == "function":
return self.map_function_definition(expr.function, depth+1)
return self.rec_function_definition(expr.function, depth+1)
else:
return self.rec(getattr(expr, field), depth+1)

Expand Down
Loading
Loading