Skip to content

Commit

Permalink
avoid traversing functions multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Dec 20, 2024
1 parent 3f90cbd commit 87b8470
Show file tree
Hide file tree
Showing 16 changed files with 233 additions and 122 deletions.
43 changes: 33 additions & 10 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@

from typing import TYPE_CHECKING, Any

from typing_extensions import Self

from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

from pytato.array import (
Array,
Expand Down Expand Up @@ -76,7 +77,7 @@

# {{{ NUserCollector

class NUserCollector(Mapper[None, []]):
class NUserCollector(Mapper[None, None, []]):
"""
A :class:`pytato.transform.Mapper` that records the number of
times an array expression is a direct dependency of other nodes.
Expand Down Expand Up @@ -317,7 +318,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:

# {{{ DirectPredecessorsGetter

class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]):
class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], None, []]):
"""
Mapper to get the
`direct predecessors
Expand Down Expand Up @@ -413,16 +414,31 @@ class NodeCountMapper(CachedWalkMapper[[]]):
Dictionary mapping node types to number of nodes of that type.
"""

def __init__(self, count_duplicates: bool = False) -> None:
def __init__(
self,
count_duplicates: bool = False,
_visited_functions: set[Any] | None = None,
) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
super().__init__()
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
self.count_duplicates = count_duplicates

def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
    """Return the cache key for the array node *expr*.

    When ``count_duplicates`` is set, the key is ``id(expr)``, so every
    distinct object gets its own key. Otherwise the key is *expr* itself,
    so nodes that compare (and hash) equal share one key.
    """
    if self.count_duplicates:
        # Identity-based key: duplicates are kept apart.
        return id(expr)
    # Value-based key: only unique nodes are distinguished.
    return expr

def get_function_definition_cache_key(
        self, expr: FunctionDefinition) -> int | FunctionDefinition:
    """Return the cache key for the function definition *expr*.

    When ``count_duplicates`` is set, the key is ``id(expr)``, so every
    distinct object gets its own key. Otherwise the key is *expr* itself,
    so definitions that compare (and hash) equal share one key.
    """
    if self.count_duplicates:
        # Identity-based key: duplicates are kept apart.
        return id(expr)
    # Value-based key: only unique definitions are distinguished.
    return expr

def clone_for_callee(self, function: FunctionDefinition) -> Self:
    """Create a fresh mapper of the same class for traversing a callee.

    The clone inherits this mapper's ``count_duplicates`` setting and is
    handed the *same* ``_visited_functions`` object (not a copy).
    *function* itself is not consulted here.
    """
    mapper_cls = type(self)
    duplicates_flag = self.count_duplicates
    # Passed by reference, so the clone and this mapper share one set.
    shared_visited = self._visited_functions
    return mapper_cls(
        count_duplicates=duplicates_flag,
        _visited_functions=shared_visited)

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_type_counts[type(expr)] += 1
Expand Down Expand Up @@ -488,15 +504,20 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]):
.. autoattribute:: expr_multiplicity_counts
"""
def __init__(self) -> None:
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)

from collections import defaultdict
super().__init__()
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)

# Cache key for array nodes: the object's identity (id()) rather than the
# object itself, so equal-but-distinct node objects get separate keys.
def get_cache_key(self, expr: ArrayOrNames) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

# Cache key for function definitions: the object's identity (id()), so
# equal-but-distinct FunctionDefinition objects get separate keys.
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_multiplicity_counts[expr] += 1
Expand Down Expand Up @@ -530,14 +551,16 @@ class CallSiteCountMapper(CachedWalkMapper[[]]):
The number of nodes.
"""

def __init__(self) -> None:
super().__init__()
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)
self.count = 0

# Cache key for array nodes: the object's identity (id()), so distinct
# objects are keyed separately even if they compare equal.
def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

@memoize_method
# Cache key for function definitions: the object's identity (id()), so
# distinct FunctionDefinition objects are keyed separately even if equal.
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def map_function_definition(self, expr: FunctionDefinition) -> None:
if not self.visit(expr):
return
Expand Down
22 changes: 14 additions & 8 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@


if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Hashable, Mapping

from pytato.function import NamedCallResult
from pytato.function import FunctionDefinition, NamedCallResult
from pytato.target import Target


Expand Down Expand Up @@ -136,10 +136,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
====================================== =====================================
"""

def __init__(self, target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None
) -> None:
super().__init__()
def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
Expand Down Expand Up @@ -266,13 +269,16 @@ def normalize_outputs(

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NamesValidityChecker(CachedWalkMapper[[]]):
def __init__(self) -> None:
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
self.name_to_input: dict[str, InputArgumentBase] = {}
super().__init__()
super().__init__(_visited_functions=_visited_functions)

# Cache key for array nodes: the object's identity (id()), so distinct
# objects are keyed separately even if they compare equal.
def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

# Cache key for function definitions: the object's identity (id()), so
# distinct FunctionDefinition objects are keyed separately even if equal.
def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if isinstance(expr, Placeholder | SizeParam | DataWrapper):
if expr.name is not None:
Expand Down
7 changes: 4 additions & 3 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
) -> None:
super().__init__()
super().__init__(_function_cache=_function_cache)

self.recvd_ary_to_name = recvd_ary_to_name
self.sptpo_ary_to_name = sptpo_ary_to_name
Expand All @@ -303,7 +304,7 @@ def clone_for_callee(
self, function: FunctionDefinition) -> _DistributedInputReplacer:
# Function definitions aren't allowed to contain receives,
# stored arrays promoted to part outputs, or part outputs
return type(self)({}, {}, {})
return type(self)({}, {}, {}, _function_cache=self._function_cache)

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
Expand Down Expand Up @@ -456,7 +457,7 @@ def _recv_to_comm_id(


class _LocalSendRecvDepGatherer(
CombineMapper[frozenset[CommunicationOpIdentifier]]):
CombineMapper[frozenset[CommunicationOpIdentifier], None]):
def __init__(self, local_rank: int) -> None:
super().__init__()
self.local_comm_ids_to_needed_comm_ids: \
Expand Down
4 changes: 2 additions & 2 deletions pytato/distributed/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ class MissingRecvError(DistributedPartitionVerificationError):

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class _SeenNodesWalkMapper(CachedWalkMapper[[]]):
def __init__(self) -> None:
super().__init__()
def __init__(self, _visited_functions: set[Any] | None = None) -> None:
super().__init__(_visited_functions=_visited_functions)
self.seen_nodes: set[ArrayOrNames] = set()

def get_cache_key(self, expr: ArrayOrNames) -> int:
Expand Down
34 changes: 18 additions & 16 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@

from typing import TYPE_CHECKING, Any

from pytools import memoize_method

from pytato.array import (
AbstractResultWithNamedArrays,
AdvancedIndexInContiguousAxes,
Expand Down Expand Up @@ -85,26 +83,31 @@ class EqualityComparer:
more on this.
"""
def __init__(self) -> None:
# Uses the same cache for both arrays and functions
self._cache: dict[tuple[int, int], bool] = {}

def rec(self, expr1: ArrayOrNames, expr2: Any) -> bool:
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
cache_key = id(expr1), id(expr2)
try:
return self._cache[cache_key]
except KeyError:

method: Callable[[Array | AbstractResultWithNamedArrays, Any],
bool]

try:
method = getattr(self, expr1._mapper_method)
except AttributeError:
if isinstance(expr1, Array):
result = self.handle_unsupported_array(expr1, expr2)
if expr1 is expr2:
result = True
elif isinstance(expr1, ArrayOrNames):
method: Callable[[ArrayOrNames, Any], bool]
try:
method = getattr(self, expr1._mapper_method)
except AttributeError:
if isinstance(expr1, Array):
result = self.handle_unsupported_array(expr1, expr2)
else:
result = self.map_foreign(expr1, expr2)
else:
result = self.map_foreign(expr1, expr2)
result = method(expr1, expr2)
elif isinstance(expr1, FunctionDefinition):
result = self.map_function_definition(expr1, expr2)
else:
result = (expr1 is expr2) or method(expr1, expr2)
result = self.map_foreign(expr1, expr2)

self._cache[cache_key] = result
return result
Expand Down Expand Up @@ -296,7 +299,6 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
and expr1.tags == expr2.tags
)

@memoize_method
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
) -> bool:
return (expr1.__class__ is expr2.__class__
Expand All @@ -310,7 +312,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any

def map_call(self, expr1: Call, expr2: Any) -> bool:
return (expr1.__class__ is expr2.__class__
and self.map_function_definition(expr1.function, expr2.function)
and self.rec(expr1.function, expr2.function)
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
and all(self.rec(bnd,
expr2.bindings[name])
Expand Down
17 changes: 12 additions & 5 deletions pytato/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import numpy as np
from immutabledict import immutabledict

from pytools import memoize_method

from pytato.array import (
Array,
Axis,
Expand All @@ -58,7 +56,7 @@

# {{{ Reprifier

class Reprifier(Mapper[str, [int]]):
class Reprifier(Mapper[str, str, [int]]):
"""
Stringifies :mod:`pytato`-types to closely resemble CPython's implementation
of :func:`repr` for its builtin datatypes.
Expand All @@ -71,6 +69,7 @@ def __init__(self,
self.truncation_depth = truncation_depth
self.truncation_string = truncation_string

# Uses the same cache for both arrays and functions
self._cache: dict[tuple[int, int], str] = {}

def rec(self, expr: Any, depth: int) -> str:
Expand All @@ -82,6 +81,15 @@ def rec(self, expr: Any, depth: int) -> str:
self._cache[cache_key] = result
return result

def rec_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
    """Return the string for the function definition *expr* at *depth*.

    Results are memoized in ``self._cache`` under ``(id(expr), depth)``.
    On a miss the work is delegated to the parent class's
    ``rec_function_definition`` and the result is stored before returning.
    """
    key = (id(expr), depth)
    try:
        text = self._cache[key]
    except KeyError:
        # Not seen at this depth yet: compute via the base class and remember.
        text = super().rec_function_definition(expr, depth)
        self._cache[key] = text
    return text

# Entry point: stringify *expr* by delegating to rec(); *depth* defaults to 0.
def __call__(self, expr: Any, depth: int = 0) -> str:
return self.rec(expr, depth)

Expand Down Expand Up @@ -171,7 +179,6 @@ def _get_field_val(field: str) -> str:
for field in dataclasses.fields(type(expr)))
+ ")")

@memoize_method
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string
Expand All @@ -194,7 +201,7 @@ def map_call(self, expr: Call, depth: int) -> str:

def _get_field_val(field: str) -> str:
if field == "function":
return self.map_function_definition(expr.function, depth+1)
return self.rec_function_definition(expr.function, depth+1)
else:
return self.rec(getattr(expr, field), depth+1)

Expand Down
2 changes: 1 addition & 1 deletion pytato/target/loopy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def update_t_unit(self, t_unit: lp.TranslationUnit) -> None:

# {{{ codegen mapper

class CodeGenMapper(Mapper[ImplementedResult, [CodeGenState]]):
class CodeGenMapper(Mapper[ImplementedResult, None, [CodeGenState]]):
"""A mapper for generating code for nodes in the computation graph.
"""
exprgen_mapper: InlinedExpressionGenMapper
Expand Down
3 changes: 2 additions & 1 deletion pytato/target/python/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
SizeParam,
Stack,
)
from pytato.function import FunctionDefinition
from pytato.raising import BinaryOpType, C99CallOp
from pytato.reductions import (
AllReductionOperation,
Expand Down Expand Up @@ -171,7 +172,7 @@ def _is_slice_trivial(slice_: NormalizedSlice,
}


class NumpyCodegenMapper(CachedMapper[str, []]):
class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition, []]):
"""
.. note::
Expand Down
Loading

0 comments on commit 87b8470

Please sign in to comment.