Skip to content

Commit

Permalink
remove map_foreign from Mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jan 9, 2025
1 parent ca4f336 commit 5d4b0e2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
7 changes: 5 additions & 2 deletions pytato/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
IndexLambda,
ReductionDescriptor,
)
from pytato.transform import Mapper
from pytato.transform import ForeignObjectError, Mapper


if TYPE_CHECKING:
Expand Down Expand Up @@ -77,7 +77,10 @@ def rec(self, expr: Any, depth: int) -> str:
try:
return self._cache[cache_key]
except KeyError:
result = super().rec(expr, depth)
try:
result = super().rec(expr, depth)
except ForeignObjectError:
result = self.map_foreign(expr, depth)
self._cache[cache_key] = result
return result

Expand Down
19 changes: 9 additions & 10 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ class UnsupportedArrayError(ValueError):
pass


class ForeignObjectError(ValueError):
pass


# {{{ mapper base class

ResultT = TypeVar("ResultT")
Expand All @@ -185,7 +189,6 @@ class Mapper(Generic[ResultT, FunctionResultT, P]):
if this is not desired.
.. automethod:: handle_unsupported_array
.. automethod:: map_foreign
.. automethod:: rec
.. automethod:: __call__
"""
Expand All @@ -199,13 +202,6 @@ def handle_unsupported_array(self, expr: MappedT,
raise UnsupportedArrayError(
f"{type(self).__name__} cannot handle expressions of type {type(expr)}")

def map_foreign(self, expr: Any, *args: P.args, **kwargs: P.kwargs) -> Any:
"""Mapper method that is invoked for an object of class for which a
mapper method does not exist in this mapper.
"""
raise ValueError(
f"{type(self).__name__} encountered invalid foreign object: {expr!r}")

def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
"""Call the mapper method of *expr* and return the result."""
method: Callable[..., Any] | None
Expand All @@ -223,7 +219,9 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
else:
return self.handle_unsupported_array(expr, *args, **kwargs)
else:
return cast("ResultT", self.map_foreign(expr, *args, **kwargs))
raise ForeignObjectError(
f"{type(self).__name__} encountered invalid foreign "
f"object: {expr!r}") from None

assert method is not None
return cast("ResultT", method(expr, *args, **kwargs))
Expand All @@ -237,7 +235,8 @@ def rec_function_definition(
try:
method = self.map_function_definition # type: ignore[attr-defined]
except AttributeError:
return cast("FunctionResultT", self.map_foreign(expr, *args, **kwargs))
raise ValueError(
f"{type(self).__name__} lacks a mapper method for functions.") from None

assert method is not None
return cast("FunctionResultT", method(expr, *args, **kwargs))
Expand Down

0 comments on commit 5d4b0e2

Please sign in to comment.