From c74acdf77ac09dfbf9b07813ab8390cf6cb9636f Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 18 Apr 2023 18:57:18 -0500 Subject: [PATCH] make the parametric type of CachedMapper to be the return type --- pytato/codegen.py | 3 +- pytato/distributed/partition.py | 21 ++++-- pytato/partition.py | 5 +- pytato/target/python/numpy_like.py | 9 ++- pytato/transform/__init__.py | 68 ++++++++++++++------ pytato/transform/remove_broadcasts_einsum.py | 4 +- 6 files changed, 76 insertions(+), 34 deletions(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index fb208a5a5..63067fbd9 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -83,7 +83,8 @@ def _generate_name_for_temp( # {{{ preprocessing for codegen -class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): +# type-ignore-reason: incompatible 'rec' types between ToIndexLambdaMixin, CopyMapper +class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] """A mapper that preprocesses graphs to simplify code generation. The following node simplifications are performed: diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 3da47b91e..e476318f9 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -50,7 +50,7 @@ NamedArray) from pytato.transform import (ArrayOrNames, CopyMapper, Mapper, CachedWalkMapper, CopyMapperWithExtraArgs, - CombineMapper) + CombineMapper, CopyMapperT) from pytato.partition import GraphPart, GraphPartition, PartId, GraphPartitioner from pytato.distributed.nodes import ( DistributedRecv, DistributedSend, DistributedSendRefHolder) @@ -233,10 +233,12 @@ def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None: def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder, *args: Any) -> Any: send_part_id = self.get_part_id(expr.send.data) + rec_send_data = self.rec(expr.send.data) + assert isinstance(rec_send_data, Array) self.pid_to_dist_sends.setdefault(send_part_id, []).append( DistributedSend( - data=self.rec(expr.send.data), + data=rec_send_data, dest_rank=expr.send.dest_rank, comm_tag=expr.send.comm_tag, tags=expr.send.tags)) @@ -560,7 +562,7 @@ def __init__(self, # type-ignore reason: incompatible attribute type wrt base. self._cache: Dict[Tuple[ArrayOrNames, int], - Any] = {} # type: ignore[assignment] + ArrayOrNames] = {} # type: ignore[assignment] # type-ignore-reason: incompatible with super class def get_cache_key(self, # type: ignore[override] @@ -572,11 +574,12 @@ def get_cache_key(self, # type: ignore[override] # type-ignore-reason: incompatible with super class def rec(self, # type: ignore[override] - expr: ArrayOrNames, - user_part_id: int) -> Any: + expr: CopyMapperT, + user_part_id: int) -> CopyMapperT: key = self.get_cache_key(expr, user_part_id) try: - return self._cache[key] + # type-ignore-reason: parametric dicts are not a thing in typing module + return self._cache[key] # type: ignore[return-value] except KeyError: if isinstance(expr, Array): if expr in self.stored_array_to_part_id: @@ -592,6 +595,12 @@ def rec(self, # type: ignore[override] self._cache[key] = result return result + # type-ignore-reason: incompatible with super class + def __call__(self, # type: ignore[override] + expr: CopyMapperT, + user_part_id: int) -> CopyMapperT: + return self.rec(expr, user_part_id) + def _remove_part_id_tag(ary: ArrayOrNames) -> Array: assert isinstance(ary, Array) diff --git a/pytato/partition.py b/pytato/partition.py index 8b1596844..1cfa5f541 100644 --- a/pytato/partition.py +++ b/pytato/partition.py @@ -161,7 +161,10 @@ def handle_edge(self, expr: ArrayOrNames, child: ArrayOrNames) -> Any: tags=child.tags, axes=child.axes) - self.var_name_to_result[ph_name] = self.rec(child) + # type-ignore-reason: mypy is right, types of self.rec are + # imprecise (TODO) + self.var_name_to_result[ph_name] = ( + self.rec(child)) # type: ignore[assignment] self._seen_node_to_placeholder[child] = ph diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 4d522e507..10c5fe4f9 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -31,7 +31,7 @@ cast, List, Set, Tuple, Type) from pytools import UniqueNameGenerator -from pytato.transform import CachedMapper, ArrayOrNames +from pytato.transform import CachedMapper from pytato.array import (Stack, Concatenate, IndexLambda, DataWrapper, Placeholder, SizeParam, Roll, AxisPermutation, Einsum, @@ -164,7 +164,7 @@ def _is_slice_trivial(slice_: NormalizedSlice, } -class NumpyCodegenMapper(CachedMapper[ArrayOrNames]): +class NumpyCodegenMapper(CachedMapper[str]): """ .. note:: @@ -408,7 +408,7 @@ def _map_index_base(self, expr: IndexBase) -> str: ) if last_non_trivial_index == -1: - return self.rec(expr.array) # type: ignore[no-any-return] + return self.rec(expr.array) lhs = self.vng("_pt_tmp") @@ -500,8 +500,7 @@ def map_reshape(self, expr: Reshape) -> str: return self._record_line_and_return_lhs(lhs, rhs) def map_named_array(self, expr: NamedArray) -> str: - # type-ignore-reason: CachedMapper.rec's types are imprecise - return self.rec(expr.expr) # type: ignore[no-any-return] + return self.rec(expr.expr) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> str: lhs = self.vng("_pt_tmp") diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5583e359e..9e4c7b8af 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -49,8 +49,11 @@ from pymbolic.mapper.optimize import optimize_mapper ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] -MappedT = TypeVar("MappedT", bound=ArrayOrNames) +MappedT = TypeVar("MappedT", + Array, AbstractResultWithNamedArrays, ArrayOrNames) CombineT = TypeVar("CombineT") # used in CombineMapper +CopyMapperT = TypeVar("CopyMapperT", # used in CopyMapper + Array, AbstractResultWithNamedArrays, ArrayOrNames) CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = FrozenSet[Array] @@ -180,20 +183,26 @@ class CachedMapper(Mapper, Generic[CachedMapperT]): """ def __init__(self) -> None: - self._cache: Dict[CachedMapperT, Any] = {} + self._cache: Dict[Any, CachedMapperT] = {} - def get_cache_key(self, expr: CachedMapperT) -> Any: + def get_cache_key(self, expr: ArrayOrNames) -> Any: return expr # type-ignore-reason: incompatible with super class - def rec(self, expr: CachedMapperT) -> Any: # type: ignore[override] + def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override] key = self.get_cache_key(expr) try: return self._cache[key] except KeyError: - result = super().rec(expr) # type: ignore[type-var] + result = super().rec(expr) self._cache[key] = result - return result + # type-ignore-reason: Mapper.rec has imprecise func. signature + return result # type: ignore[no-any-return] + + # type-ignore-reason: incompatible with super class + def __call__(self, expr: ArrayOrNames # type: ignore[override] + ) -> CachedMapperT: + return self.rec(expr) # }}} @@ -210,9 +219,21 @@ class CopyMapper(CachedMapper[ArrayOrNames]): This does not copy the data of a :class:`pytato.array.DataWrapper`. """ + # type-ignore-reason: specialized variant of super-class' rec method + def rec(self, expr: CopyMapperT) -> CopyMapperT: # type: ignore[override] + # type-ignore-reason: CachedMapper.rec's return type is imprecise + return super().rec(expr) # type: ignore[return-value] + + # type-ignore-reason: specialized variant of super-class' rec method + def __call__(self, expr: CopyMapperT) -> CopyMapperT: # type: ignore[override] + return self.rec(expr) + def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] ) -> Tuple[IndexOrShapeExpr, ...]: - return tuple(self.rec(s) if isinstance(s, Array) else s for s in situp) + # type-ignore-reason: apparently mypy cannot substitute typevars + # here. + return tuple(self.rec(s) if isinstance(s, Array) else s # type: ignore[misc] + for s in situp) def map_index_lambda(self, expr: IndexLambda) -> Array: bindings: Dict[str, Array] = { @@ -319,8 +340,10 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: ) def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: + rec_container = self.rec(expr._container) + assert isinstance(rec_container, LoopyCall) return LoopyCallResult( - loopy_call=self.rec(expr._container), + loopy_call=rec_container, name=expr.name, axes=expr.axes, tags=expr.tags) @@ -364,7 +387,7 @@ def __init__(self) -> None: Tuple[Any, ...], Tuple[Tuple[str, Any], ...] ], - Any] = {} # type: ignore[assignment] + ArrayOrNames] = {} def get_cache_key(self, expr: ArrayOrNames, @@ -375,23 +398,30 @@ def get_cache_key(self, return (expr, args, tuple(sorted(kwargs.items()))) def rec(self, - expr: ArrayOrNames, - *args: Any, **kwargs: Any) -> Any: + expr: CopyMapperT, + *args: Any, **kwargs: Any) -> CopyMapperT: key = self.get_cache_key(expr, *args, **kwargs) try: - return self._cache[key] + # type-ignore-reason: self._cache has ArrayOrNames as its values + return self._cache[key] # type: ignore[return-value] except KeyError: result = Mapper.rec(self, expr, *args, **kwargs) self._cache[key] = result - return result + # type-ignore-reason: Mapper.rec is imprecise + return result # type: ignore[no-any-return] def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...], *args: Any, **kwargs: Any ) -> Tuple[IndexOrShapeExpr, ...]: - return tuple(self.rec(s, *args, **kwargs) if isinstance(s, Array) else s - for s in situp) + # type-ignore-reason: apparently mypy cannot substitute typevars + # here. + return tuple( + self.rec(s, *args, **kwargs) # type: ignore[misc] + if isinstance(s, Array) + else s + for s in situp) def map_index_lambda(self, expr: IndexLambda, *args: Any, **kwargs: Any) -> Array: @@ -510,8 +540,10 @@ def map_loopy_call(self, expr: LoopyCall, def map_loopy_call_result(self, expr: LoopyCallResult, *args: Any, **kwargs: Any) -> Array: + rec_loopy_call = self.rec(expr._container, *args, **kwargs) + assert isinstance(rec_loopy_call, LoopyCall) return LoopyCallResult( - loopy_call=self.rec(expr._container, *args, **kwargs), + loopy_call=rec_loopy_call, name=expr.name, axes=expr.axes, tags=expr.tags) @@ -1018,11 +1050,11 @@ def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: # type-ignore-reason:incompatible with Mapper.rec() def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override] if expr in self._cache: - return self._cache[expr] # type: ignore[no-any-return] + return self._cache[expr] result = super().rec(self.map_fn(expr)) self._cache[expr] = result - return result # type: ignore[no-any-return] + return result # type-ignore-reason: Mapper.__call__ returns Any def __call__(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override] diff --git a/pytato/transform/remove_broadcasts_einsum.py b/pytato/transform/remove_broadcasts_einsum.py index 2d0273697..22e45a117 100644 --- a/pytato/transform/remove_broadcasts_einsum.py +++ b/pytato/transform/remove_broadcasts_einsum.py @@ -96,8 +96,6 @@ def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: alter its value. """ mapper = EinsumWithNoBroadcastsRewriter() - - # type-ignore-reason: mypy is right i.e. CopyMapper.__call__ is imprecise - return mapper(expr) # type: ignore[no-any-return] + return mapper(expr) # vim:fdm=marker