Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache mapped functions #531

Merged
merged 13 commits into from
Jan 17, 2025
Merged

Conversation

majosm
Copy link
Collaborator

@majosm majosm commented Jul 26, 2024

Caches the results of mapping FunctionDefinitions. Caching is done globally to avoid re-traversing functions in different call stack frames.

Depends on #503 (merged) and #530 (merged).

@majosm majosm force-pushed the cache-mapped-function-results branch from 0ed120a to 0c73b20 Compare July 26, 2024 04:54
@majosm majosm mentioned this pull request Jul 26, 2024
@majosm majosm force-pushed the cache-mapped-function-results branch 6 times, most recently from dd6d54f to 2ff245d Compare August 2, 2024 21:33
@majosm majosm marked this pull request as ready for review August 2, 2024 22:24
@majosm
Copy link
Collaborator Author

majosm commented Aug 2, 2024

@inducer Should be ready for a look. I think the mypy errors are unrelated, they're occurring on main too.

@majosm majosm requested a review from inducer August 2, 2024 22:26
@majosm majosm force-pushed the cache-mapped-function-results branch from 2ff245d to f8c2c68 Compare August 12, 2024 20:45
@inducer inducer force-pushed the cache-mapped-function-results branch from f8c2c68 to 58c8dbf Compare August 25, 2024 04:54
@majosm majosm force-pushed the cache-mapped-function-results branch 2 times, most recently from 4519ce2 to 913d3d1 Compare September 5, 2024 19:23
@majosm majosm force-pushed the cache-mapped-function-results branch 2 times, most recently from 312e8ff to 72ce101 Compare September 24, 2024 17:47
This was referenced Sep 24, 2024
@inducer
Copy link
Owner

inducer commented Nov 14, 2024

Uh-oh. The recent typing apocalypse has made a bit of a mess of this. I've started resolving conflicts, but I got bogged down in pytato.transform. Could you take a look? (Or, if you've lost patience with me, I can also try and finish.)

@majosm majosm force-pushed the cache-mapped-function-results branch 4 times, most recently from b8cbbc5 to e57fe22 Compare December 20, 2024 02:36
@majosm
Copy link
Collaborator Author

majosm commented Dec 20, 2024

@inducer Should be ready for review again (mypy error looks unrelated?). In e57fe22 I added a few questions/comments about things I noticed while resolving conflicts. I'll remove them after you've had a chance to glance at them.

@inducer
Copy link
Owner

inducer commented Dec 20, 2024

#571 should address the mypy issue.

@majosm majosm force-pushed the cache-mapped-function-results branch from e57fe22 to a4b609c Compare January 8, 2025 16:01
Comment on lines 231 to 240
try:
method = self.map_function_definition # type: ignore[attr-defined]
except AttributeError:
return cast("FunctionResultT", self.map_foreign(expr, *args, **kwargs))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be an error instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -194,7 +193,7 @@ def handle_unsupported_array(self, expr: MappedT,
raise UnsupportedArrayError(
f"{type(self).__name__} cannot handle expressions of type {type(expr)}")

def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> ResultT:
def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> Any:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can get rid of map_foreign.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reprifier appears to override map_foreign to handle stringifying non-array fields. Seems like it needs to stay?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move it to Reprifier then. I don't think it needs to be a feature of the generic machinery.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

else:
function_cache = {}

self._function_cache: dict[Hashable, FunctionResultT] = function_cache
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be an if expression?

Comment on lines 339 to 341
# Why multiple inheriting?
CachedMapper[ArrayOrNames, FunctionDefinition, P],
Mapper[ArrayOrNames, FunctionDefinition, P]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Why multiple inheriting?
CachedMapper[ArrayOrNames, FunctionDefinition, P],
Mapper[ArrayOrNames, FunctionDefinition, P]
CachedMapper[ArrayOrNames, FunctionDefinition, P],

I don't think it should be.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to work fine without it.

"""
# FIXME: This inflates recursion depth
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if TYPE_CHECKING:
    ... (same)
else:
    @property
    def rec_ary(self):
        return self.rec

"""
# FIXME: This inflates recursion depth
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def is_array(expr: ArrayOrNames) -> Array:
    assert isinstance(expr, Array)
    return expr

(This is probably the saner way.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like ca4f336?

(The other way appears to work too: majosm@34af8ef)

Comment on lines 783 to 785
# Don't need to pass function cache as argument here, because unlike
# CachedMapper we're not creating a new mapper for each call
self.function_cache: dict[FunctionDefinition, FunctionResultT] = {}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the behavior should be consistent across mapper types.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added _function_cache to the __init__ args like the others. I didn't go the whole way and add clone_for_callee, etc., yet because I'm worried that's going to make a mess of my downstream changes.

else:
visited_functions = set()

self._visited_functions: set[Any] = visited_functions
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if expr?

@@ -384,7 +384,7 @@ def update_t_unit(self, t_unit: lp.TranslationUnit) -> None:

# {{{ codegen mapper

class CodeGenMapper(Mapper[ImplementedResult, [CodeGenState]]):
class CodeGenMapper(Mapper[ImplementedResult, None, [CodeGenState]]):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class CodeGenMapper(Mapper[ImplementedResult, None, [CodeGenState]]):
class CodeGenMapper(Mapper[ImplementedResult, Never, [CodeGenState]]):

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added here and in a few other places.

@majosm majosm force-pushed the cache-mapped-function-results branch 4 times, most recently from 610d413 to 5d4b0e2 Compare January 9, 2025 21:30
@majosm majosm force-pushed the cache-mapped-function-results branch from 5d4b0e2 to 1890af8 Compare January 9, 2025 21:43
@majosm majosm requested a review from inducer January 9, 2025 22:18
@majosm
Copy link
Collaborator Author

majosm commented Jan 9, 2025

Ready for another look I think @inducer.

Copy link
Owner

@inducer inducer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! A few questions/comments below.

def __init__(
self,
count_duplicates: bool = False,
_visited_functions: set[Any] | None = None,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can something more precise than Any be said here? (Also for other _visited_functions annotations.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like I did that to match _visited_arrays_or_names in CachedWalkMapper, which also uses Any. I guess I could change both to Hashable?

@@ -171,7 +172,7 @@ def _is_slice_trivial(slice_: NormalizedSlice,
}


class NumpyCodegenMapper(CachedMapper[str, []]):
class NumpyCodegenMapper(CachedMapper[str, FunctionDefinition, []]):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this annotation right? Does the numpy target generate appropriate code for functions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NumpyCodegenMapper doesn't support functions, so I guess this should be Never. Same for FancyDotWriter. Fixed these.

@@ -58,6 +58,7 @@ class PlaceholderSubstitutor(CopyMapper):
"""

def __init__(self, substitutions: Mapping[str, Array]) -> None:
# Ignoring function cache, not needed
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come? Not using the function cache may lead to large amounts of redundant work?

(I'm OK with not acting on this right now, but maybe this comment should be a FIXME?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PlaceholderSubstitutor should only be called on function bodies that don't have functions nested in them (since the placeholder names can overlap). I've disabled PlaceholderSubstitutor.map_named_call_result to make that more explicit.

pytato/transform/__init__.py Show resolved Hide resolved
@majosm
Copy link
Collaborator Author

majosm commented Jan 16, 2025

I changed CachedWalkMapper to use an alias of Hashable instead of Any, and I removed a few other uses of Any that I found. Should be ready for another look.

@majosm majosm requested a review from inducer January 16, 2025 22:56
Copy link
Owner

@inducer inducer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This looks great. In it goes.

pytato/transform/__init__.py Show resolved Hide resolved
@inducer inducer merged commit 46431ee into inducer:main Jan 17, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants