-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cache mapped functions #531
Conversation
0ed120a
to
0c73b20
Compare
dd6d54f
to
2ff245d
Compare
@inducer Should be ready for a look. I think the mypy errors are unrelated, they're occurring on main too. |
2ff245d
to
f8c2c68
Compare
f8c2c68
to
58c8dbf
Compare
4519ce2
to
913d3d1
Compare
312e8ff
to
72ce101
Compare
Uh-oh. The recent typing apocalypse has made a bit of a mess of this. I've started resolving conflicts, but I got bogged down in |
b8cbbc5
to
e57fe22
Compare
#571 should address the mypy issue. |
e57fe22
to
a4b609c
Compare
pytato/transform/__init__.py
Outdated
try: | ||
method = self.map_function_definition # type: ignore[attr-defined] | ||
except AttributeError: | ||
return cast("FunctionResultT", self.map_foreign(expr, *args, **kwargs)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be an error instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
pytato/transform/__init__.py
Outdated
@@ -194,7 +193,7 @@ def handle_unsupported_array(self, expr: MappedT, | |||
raise UnsupportedArrayError( | |||
f"{type(self).__name__} cannot handle expressions of type {type(expr)}") | |||
|
|||
def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> ResultT: | |||
def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can get rid of map_foreign
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reprifier
appears to override map_foreign
to handle stringifying non-array fields. Seems like it needs to stay?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move it to Reprifier
then. I don't think it needs to be a feature of the generic machinery.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
pytato/transform/__init__.py
Outdated
else: | ||
function_cache = {} | ||
|
||
self._function_cache: dict[Hashable, FunctionResultT] = function_cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be an if
expression?
pytato/transform/__init__.py
Outdated
# Why multiple inheriting? | ||
CachedMapper[ArrayOrNames, FunctionDefinition, P], | ||
Mapper[ArrayOrNames, FunctionDefinition, P] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Why multiple inheriting? | |
CachedMapper[ArrayOrNames, FunctionDefinition, P], | |
Mapper[ArrayOrNames, FunctionDefinition, P] | |
CachedMapper[ArrayOrNames, FunctionDefinition, P], |
I don't think it should be.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems to work fine without it.
pytato/transform/__init__.py
Outdated
""" | ||
# FIXME: This inflates recursion depth |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if TYPE_CHECKING:
... (same)
else:
@property
def rec_ary(self):
return self.rec
pytato/transform/__init__.py
Outdated
""" | ||
# FIXME: This inflates recursion depth |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def is_array(expr: ArrayOrNames) -> Array:
assert isinstance(expr, Array)
return expr
(This is probably the saner way.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something like ca4f336?
(The other way appears to work too: majosm@34af8ef)
pytato/transform/__init__.py
Outdated
# Don't need to pass function cache as argument here, because unlike | ||
# CachedMapper we're not creating a new mapper for each call | ||
self.function_cache: dict[FunctionDefinition, FunctionResultT] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the behavior should be consistent across mapper types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added _function_cache
to the __init__
args like the others. I didn't go the whole way and add clone_for_callee
, etc., yet because I'm worried that's going to make a mess of my downstream changes.
pytato/transform/__init__.py
Outdated
else: | ||
visited_functions = set() | ||
|
||
self._visited_functions: set[Any] = visited_functions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if
expr?
pytato/target/loopy/codegen.py
Outdated
@@ -384,7 +384,7 @@ def update_t_unit(self, t_unit: lp.TranslationUnit) -> None: | |||
|
|||
# {{{ codegen mapper | |||
|
|||
class CodeGenMapper(Mapper[ImplementedResult, [CodeGenState]]): | |||
class CodeGenMapper(Mapper[ImplementedResult, None, [CodeGenState]]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class CodeGenMapper(Mapper[ImplementedResult, None, [CodeGenState]]): | |
class CodeGenMapper(Mapper[ImplementedResult, Never, [CodeGenState]]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added here and in a few other places.
610d413
to
5d4b0e2
Compare
…t don't support functions
doesn't appear to be needed
the latter inflates recursion depth
5d4b0e2
to
1890af8
Compare
Ready for another look I think @inducer. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! A few questions/comments below.
def __init__( | ||
self, | ||
count_duplicates: bool = False, | ||
_visited_functions: set[Any] | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can something more precise than Any
be said here? (Also for other _visited_functions
annotations.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like I did that to match _visited_arrays_or_names
in CachedWalkMapper
, which also uses Any
. I guess I could change both to Hashable
?
pytato/target/python/numpy_like.py
Outdated
@@ -171,7 +172,7 @@ def _is_slice_trivial(slice_: NormalizedSlice, | |||
} | |||
|
|||
|
|||
class NumpyCodegenMapper(CachedMapper[str, []]): | |||
class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition, []]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this annotation right? Does the numpy target generate appropriate code for functions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NumpyCodegenMapper
doesn't support functions, so I guess this should be Never
. Same for FancyDotWriter
. Fixed these.
pytato/transform/calls.py
Outdated
@@ -58,6 +58,7 @@ class PlaceholderSubstitutor(CopyMapper): | |||
""" | |||
|
|||
def __init__(self, substitutions: Mapping[str, Array]) -> None: | |||
# Ignoring function cache, not needed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How come? Not using the function cache may lead to large amounts of redundant work?
(I'm OK with not acting on this right now, but maybe this comment should be a FIXME?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PlaceholderSubstitutor
should only be called on function bodies that don't have functions nested in them (since the placeholder names can overlap). I've disabled PlaceholderSubstitutor.map_named_call_result
to make that more explicit.
I changed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! This looks great. In it goes.
Caches the results of mapping
FunctionDefinition
s. Caching is done globally to avoid re-traversing functions in different call stack frames.Depends on
#503(merged) and#530(merged).