Skip to content

Commit

Permalink
make the parametric type of CachedMapper to be the return type
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Apr 19, 2023
1 parent e45d63b commit b4de0ab
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 34 deletions.
3 changes: 2 additions & 1 deletion pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def _generate_name_for_temp(

# {{{ preprocessing for codegen

class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper):
# type-ignore-reason: incompatible 'rec' types between ToIndexLambdaMixin, CopyMapper
class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
"""A mapper that preprocesses graphs to simplify code generation.
The following node simplifications are performed:
Expand Down
21 changes: 15 additions & 6 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
NamedArray)
from pytato.transform import (ArrayOrNames, CopyMapper, Mapper,
CachedWalkMapper, CopyMapperWithExtraArgs,
CombineMapper)
CombineMapper, CopyMapperT)
from pytato.partition import GraphPart, GraphPartition, PartId, GraphPartitioner
from pytato.distributed.nodes import (
DistributedRecv, DistributedSend, DistributedSendRefHolder)
Expand Down Expand Up @@ -233,10 +233,12 @@ def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None:
def map_distributed_send_ref_holder(
self, expr: DistributedSendRefHolder, *args: Any) -> Any:
send_part_id = self.get_part_id(expr.send.data)
rec_send_data = self.rec(expr.send.data)
assert isinstance(rec_send_data, Array)

self.pid_to_dist_sends.setdefault(send_part_id, []).append(
DistributedSend(
data=self.rec(expr.send.data),
data=rec_send_data,
dest_rank=expr.send.dest_rank,
comm_tag=expr.send.comm_tag,
tags=expr.send.tags))
Expand Down Expand Up @@ -560,7 +562,7 @@ def __init__(self,

# type-ignore reason: incompatible attribute type wrt base.
self._cache: Dict[Tuple[ArrayOrNames, int],
Any] = {} # type: ignore[assignment]
ArrayOrNames] = {} # type: ignore[assignment]

# type-ignore-reason: incompatible with super class
def get_cache_key(self, # type: ignore[override]
Expand All @@ -572,11 +574,12 @@ def get_cache_key(self, # type: ignore[override]

# type-ignore-reason: incompatible with super class
def rec(self, # type: ignore[override]
expr: ArrayOrNames,
user_part_id: int) -> Any:
expr: CopyMapperT,
user_part_id: int) -> CopyMapperT:
key = self.get_cache_key(expr, user_part_id)
try:
return self._cache[key]
# type-ignore-reason: parametric dicts are not a thing in typing module
return self._cache[key] # type: ignore[return-value]
except KeyError:
if isinstance(expr, Array):
if expr in self.stored_array_to_part_id:
Expand All @@ -592,6 +595,12 @@ def rec(self, # type: ignore[override]
self._cache[key] = result
return result

# type-ignore-reason: incompatible with super class
def __call__(self, # type: ignore[override]
expr: CopyMapperT,
user_part_id: int) -> CopyMapperT:
return self.rec(expr, user_part_id)


def _remove_part_id_tag(ary: ArrayOrNames) -> Array:
assert isinstance(ary, Array)
Expand Down
5 changes: 4 additions & 1 deletion pytato/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ def handle_edge(self, expr: ArrayOrNames, child: ArrayOrNames) -> Any:
tags=child.tags,
axes=child.axes)

self.var_name_to_result[ph_name] = self.rec(child)
# type-ignore-reason: mypy is right, types of self.rec are
# imprecise (TODO)
self.var_name_to_result[ph_name] = (
self.rec(child)) # type: ignore[assignment]

self._seen_node_to_placeholder[child] = ph

Expand Down
9 changes: 4 additions & 5 deletions pytato/target/python/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
cast, List, Set, Tuple, Type)

from pytools import UniqueNameGenerator
from pytato.transform import CachedMapper, ArrayOrNames
from pytato.transform import CachedMapper
from pytato.array import (Stack, Concatenate, IndexLambda, DataWrapper,
Placeholder, SizeParam, Roll,
AxisPermutation, Einsum,
Expand Down Expand Up @@ -164,7 +164,7 @@ def _is_slice_trivial(slice_: NormalizedSlice,
}


class NumpyCodegenMapper(CachedMapper[ArrayOrNames]):
class NumpyCodegenMapper(CachedMapper[str]):
"""
.. note::
Expand Down Expand Up @@ -408,7 +408,7 @@ def _map_index_base(self, expr: IndexBase) -> str:
)

if last_non_trivial_index == -1:
return self.rec(expr.array) # type: ignore[no-any-return]
return self.rec(expr.array)

lhs = self.vng("_pt_tmp")

Expand Down Expand Up @@ -500,8 +500,7 @@ def map_reshape(self, expr: Reshape) -> str:
return self._record_line_and_return_lhs(lhs, rhs)

def map_named_array(self, expr: NamedArray) -> str:
# type-ignore-reason: CachedMapper.rec's types are imprecise
return self.rec(expr.expr) # type: ignore[no-any-return]
return self.rec(expr.expr)

def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> str:
lhs = self.vng("_pt_tmp")
Expand Down
68 changes: 50 additions & 18 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@
from pymbolic.mapper.optimize import optimize_mapper

ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
MappedT = TypeVar("MappedT", bound=ArrayOrNames)
MappedT = TypeVar("MappedT",
Array, AbstractResultWithNamedArrays, ArrayOrNames)
CombineT = TypeVar("CombineT") # used in CombineMapper
CopyMapperT = TypeVar("CopyMapperT", # used in CopyMapper
Array, AbstractResultWithNamedArrays, ArrayOrNames)
CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper
IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
R = FrozenSet[Array]
Expand Down Expand Up @@ -180,20 +183,26 @@ class CachedMapper(Mapper, Generic[CachedMapperT]):
"""

def __init__(self) -> None:
self._cache: Dict[CachedMapperT, Any] = {}
self._cache: Dict[Any, CachedMapperT] = {}

def get_cache_key(self, expr: CachedMapperT) -> Any:
def get_cache_key(self, expr: ArrayOrNames) -> Any:
return expr

# type-ignore-reason: incompatible with super class
def rec(self, expr: CachedMapperT) -> Any: # type: ignore[override]
def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override]
key = self.get_cache_key(expr)
try:
return self._cache[key]
except KeyError:
result = super().rec(expr) # type: ignore[type-var]
result = super().rec(expr)
self._cache[key] = result
return result
# type-ignore-reason: Mapper.rec has imprecise func. signature
return result # type: ignore[no-any-return]

# type-ignore-reason: incompatible with super class
def __call__(self, expr: ArrayOrNames # type: ignore[override]
) -> CachedMapperT:
return self.rec(expr)

# }}}

Expand All @@ -210,9 +219,21 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
This does not copy the data of a :class:`pytato.array.DataWrapper`.
"""

# type-ignore-reason: specialized variant of super-class' rec method
def rec(self, expr: CopyMapperT) -> CopyMapperT: # type: ignore[override]
# type-ignore-reason: CachedMapper.rec's return type is imprecise
return super().rec(expr) # type: ignore[return-value]

# type-ignore-reason: specialized variant of super-class' rec method
def __call__(self, expr: CopyMapperT) -> CopyMapperT: # type: ignore[override]
return self.rec(expr)

def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...]
) -> Tuple[IndexOrShapeExpr, ...]:
return tuple(self.rec(s) if isinstance(s, Array) else s for s in situp)
# type-ignore-reason: apparently mypy cannot substitute typevars
# here.
return tuple(self.rec(s) if isinstance(s, Array) else s # type: ignore[misc]
for s in situp)

def map_index_lambda(self, expr: IndexLambda) -> Array:
bindings: Dict[str, Array] = {
Expand Down Expand Up @@ -319,8 +340,10 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
)

def map_loopy_call_result(self, expr: LoopyCallResult) -> Array:
rec_container = self.rec(expr._container)
assert isinstance(rec_container, LoopyCall)
return LoopyCallResult(
loopy_call=self.rec(expr._container),
loopy_call=rec_container,
name=expr.name,
axes=expr.axes,
tags=expr.tags)
Expand Down Expand Up @@ -364,7 +387,7 @@ def __init__(self) -> None:
Tuple[Any, ...],
Tuple[Tuple[str, Any], ...]
],
Any] = {} # type: ignore[assignment]
ArrayOrNames] = {}

def get_cache_key(self,
expr: ArrayOrNames,
Expand All @@ -375,23 +398,30 @@ def get_cache_key(self,
return (expr, args, tuple(sorted(kwargs.items())))

def rec(self,
expr: ArrayOrNames,
*args: Any, **kwargs: Any) -> Any:
expr: CopyMapperT,
*args: Any, **kwargs: Any) -> CopyMapperT:
key = self.get_cache_key(expr, *args, **kwargs)
try:
return self._cache[key]
# type-ignore-reason: self._cache has ArrayOrNames as its values
return self._cache[key] # type: ignore[return-value]
except KeyError:
result = Mapper.rec(self, expr,
*args,
**kwargs)
self._cache[key] = result
return result
# type-ignore-reason: Mapper.rec is imprecise
return result # type: ignore[no-any-return]

def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...],
*args: Any, **kwargs: Any
) -> Tuple[IndexOrShapeExpr, ...]:
return tuple(self.rec(s, *args, **kwargs) if isinstance(s, Array) else s
for s in situp)
# type-ignore-reason: apparently mypy cannot substitute typevars
# here.
return tuple(
self.rec(s, *args, **kwargs) # type: ignore[misc]
if isinstance(s, Array)
else s
for s in situp)

def map_index_lambda(self, expr: IndexLambda,
*args: Any, **kwargs: Any) -> Array:
Expand Down Expand Up @@ -510,8 +540,10 @@ def map_loopy_call(self, expr: LoopyCall,

def map_loopy_call_result(self, expr: LoopyCallResult,
*args: Any, **kwargs: Any) -> Array:
rec_loopy_call = self.rec(expr._container, *args, **kwargs)
assert isinstance(rec_loopy_call, LoopyCall)
return LoopyCallResult(
loopy_call=self.rec(expr._container, *args, **kwargs),
loopy_call=rec_loopy_call,
name=expr.name,
axes=expr.axes,
tags=expr.tags)
Expand Down Expand Up @@ -1018,11 +1050,11 @@ def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None:
# type-ignore-reason:incompatible with Mapper.rec()
def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override]
if expr in self._cache:
return self._cache[expr] # type: ignore[no-any-return]
return self._cache[expr]

result = super().rec(self.map_fn(expr))
self._cache[expr] = result
return result # type: ignore[no-any-return]
return result

# type-ignore-reason: Mapper.__call__ returns Any
def __call__(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override]
Expand Down
4 changes: 1 addition & 3 deletions pytato/transform/remove_broadcasts_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT:
alter its value.
"""
mapper = EinsumWithNoBroadcastsRewriter()

# type-ignore-reason: mypy is right i.e. CopyMapper.__call__ is imprecise
return mapper(expr) # type: ignore[no-any-return]
return mapper(expr)

# vim:fdm=marker

0 comments on commit b4de0ab

Please sign in to comment.