Skip to content

Commit

Permalink
Merge branch 'main' into tag_created_at
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Nov 22, 2023
2 parents 36166c6 + 7776a53 commit d110d0f
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 82 deletions.
20 changes: 17 additions & 3 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType)
from pytato.function import FunctionDefinition, Call
from pytato.function import FunctionDefinition, Call, NamedCallResult
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.loopy import LoopyCall
from pymbolic.mapper.optimize import optimize_mapper
Expand Down Expand Up @@ -174,6 +174,15 @@ def map_distributed_recv(self, expr: DistributedRecv) -> None:
self.nusers[dim] += 1
self.rec(dim)

def map_call(self, expr: Call) -> None:
for ary in expr.bindings.values():
if isinstance(ary, Array):
self.nusers[ary] += 1
self.rec(ary)

def map_named_call_result(self, expr: NamedCallResult) -> None:
self.rec(expr._container)

# }}}


Expand Down Expand Up @@ -358,6 +367,12 @@ def map_distributed_send_ref_holder(self,
) -> FrozenSet[Array]:
return frozenset([expr.passthrough_data])

def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]:
raise NotImplementedError(
"DirectPredecessorsGetter does not yet support expressions containing "
"functions.")


# }}}


Expand Down Expand Up @@ -423,10 +438,9 @@ def map_function_definition(self, /, expr: FunctionDefinition,
if not self.visit(expr):
return

new_mapper = self.clone_for_callee()
new_mapper = self.clone_for_callee(expr)
for subexpr in expr.returns.values():
new_mapper(subexpr, *args, **kwargs)

self.count += new_mapper.count

self.post_visit(expr, *args, **kwargs)
Expand Down
7 changes: 4 additions & 3 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DataInterface, SizeParam, InputArgumentBase,
make_dict_of_named_arrays)

from pytato.function import NamedCallResult
from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin

from pytato.scalar_expr import IntegralScalarExpression
Expand Down Expand Up @@ -112,9 +113,6 @@ def __init__(self, target: Target,
self.target = target
self.kernels_seen: Dict[str, lp.LoopKernel] = kernels_seen or {}

def clone_for_callee(self) -> CodeGenPreprocessor:
return CodeGenPreprocessor(self.target, self.kernels_seen)

def map_size_param(self, expr: SizeParam) -> Array:
name = expr.name
assert name is not None
Expand Down Expand Up @@ -196,6 +194,9 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
axes=expr.axes,
tags=expr.tags)

def map_named_call_result(self, expr: NamedCallResult) -> Array:
raise NotImplementedError("CodeGenPreprocessor does not support functions.")

# }}}


Expand Down
17 changes: 14 additions & 3 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
from pytato.distributed.nodes import CommTagType
from pytato.analysis import DirectPredecessorsGetter

from pytato.function import FunctionDefinition, NamedCallResult

if TYPE_CHECKING:
import mpi4py.MPI

Expand Down Expand Up @@ -291,6 +293,12 @@ def __init__(self,
self.user_input_names: Set[str] = set()
self.partition_input_name_to_placeholder: Dict[str, Placeholder] = {}

def clone_for_callee(
self, function: FunctionDefinition) -> _DistributedInputReplacer:
# Function definitions aren't allowed to contain receives,
# stored arrays promoted to part outputs, or part outputs
return type(self)({}, {}, {})

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
return expr
Expand Down Expand Up @@ -323,8 +331,6 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:

# type ignore because no args, kwargs
def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override]
assert isinstance(expr, Array)

key = self.get_cache_key(expr)
try:
return self._cache[key]
Expand All @@ -334,7 +340,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override]
# If the array is an output from the current part, it would
# be counterproductive to turn it into a placeholder: we're
# the ones who are supposed to compute it!
if expr not in self.output_arrays:
if isinstance(expr, Array) and expr not in self.output_arrays:

name = self.sptpo_ary_to_name.get(expr)
if name is not None:
Expand Down Expand Up @@ -502,6 +508,11 @@ def map_distributed_recv(

return frozenset({recv_id})

def map_named_call_result(
self, expr: NamedCallResult) -> FrozenSet[CommunicationOpIdentifier]:
raise NotImplementedError(
"LocalSendRecvDepGatherer does not support functions.")

# }}}


Expand Down
31 changes: 18 additions & 13 deletions pytato/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@
Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping)
from immutabledict import immutabledict
from functools import cached_property
from pytato.array import (Array, AbstractResultWithNamedArrays,
Placeholder, NamedArray, ShapeType, _dtype_any,
InputArgumentBase)
from pytato.array import (Array, AbstractResultWithNamedArrays,
Placeholder, NamedArray, ShapeType, _dtype_any)
from pytools.tag import Tag, Taggable

ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array])
Expand Down Expand Up @@ -132,19 +131,25 @@ class FunctionDefinition(Taggable):
@cached_property
def _placeholders(self) -> Mapping[str, Placeholder]:
from pytato.transform import InputGatherer
from functools import reduce

mapper = InputGatherer()

all_input_args: FrozenSet[InputArgumentBase] = reduce(
frozenset.union,
(mapper(ary) for ary in self.returns.values()),
frozenset()
)
all_placeholders: FrozenSet[Placeholder] = frozenset()
for ary in self.returns.values():
new_placeholders = frozenset({
arg for arg in mapper(ary)
if isinstance(arg, Placeholder)})
all_placeholders |= new_placeholders

return immutabledict({input_arg.name: input_arg
for input_arg in all_input_args
if isinstance(input_arg, Placeholder)})
# FIXME: Need a way to check for *any* captured arrays, not just placeholders
if __debug__:
pl_names = frozenset(arg.name for arg in all_placeholders)
extra_pl_names = pl_names - self.parameters
assert not extra_pl_names, \
f"Found non-argument placeholder '{next(iter(extra_pl_names))}' " \
"in function definition."

return immutabledict({arg.name: arg for arg in all_placeholders})

def get_placeholder(self, name: str) -> Placeholder:
"""
Expand All @@ -168,7 +173,7 @@ def __call__(self, **kwargs: Array

if self.parameters != frozenset(kwargs):
missing_params = self.parameters - frozenset(kwargs)
extra_params = self.parameters - frozenset(kwargs)
extra_params = frozenset(kwargs) - self.parameters

raise TypeError(
"Incorrect arguments."
Expand Down
35 changes: 35 additions & 0 deletions pytato/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@

import numpy as np

from pytools import memoize_method

from typing import Any, Dict, Tuple, cast
from pytato.transform import Mapper
from pytato.array import (Array, DataWrapper, DictOfNamedArrays, Axis,
IndexLambda, ReductionDescriptor)
from pytato.function import FunctionDefinition, Call
from pytato.loopy import LoopyCall
from immutabledict import immutabledict
import attrs
Expand Down Expand Up @@ -155,6 +158,38 @@ def _get_field_val(field: str) -> str:
for field in attrs.fields(type(expr)))
+ ")")

@memoize_method
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string

def _get_field_val(field: str) -> str:
if field == "returns":
return self.rec(getattr(expr, field), depth+1)
else:
return repr(getattr(expr, field))

return (f"{type(expr).__name__}("
+ ", ".join(f"{field.name}={_get_field_val(field.name)}"
for field in attrs.fields(type(expr)))
+ ")")

def map_call(self, expr: Call, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string

def _get_field_val(field: str) -> str:
if field == "function":
return self.map_function_definition(expr.function, depth+1)
else:
return self.rec(getattr(expr, field), depth+1)

return (f"{type(expr).__name__}("
+ ", ".join(f"{field}={_get_field_val(field)}"
for field in ["function",
"bindings"])
+ ")")

def map_loopy_call(self, expr: LoopyCall, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string
Expand Down
5 changes: 3 additions & 2 deletions pytato/target/loopy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ class CodeGenMapper(Mapper):
def __init__(self,
array_tag_t_to_not_propagate: FrozenSet[Type[Tag]],
axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]]) -> None:
super().__init__()
self.exprgen_mapper = InlinedExpressionGenMapper(axis_tag_t_to_not_propagate)
self.array_tag_t_to_not_propagate = array_tag_t_to_not_propagate
self.axis_tag_t_to_not_propagate = axis_tag_t_to_not_propagate
Expand Down Expand Up @@ -591,13 +592,13 @@ def map_named_call_result(self, expr: NamedCallResult,
raise NotImplementedError("LoopyTarget does not support outlined calls"
" (yet). As a fallback, the call"
" could be inlined using"
" pt.mark_all_calls_to_be_inlined.")
" pt.tag_all_calls_to_be_inlined.")

def map_call(self, expr: Call, state: CodeGenState) -> None:
raise NotImplementedError("LoopyTarget does not support outlined calls"
" (yet). As a fallback, the call"
" could be inlined using"
" pt.mark_all_calls_to_be_inlined.")
" pt.tag_all_calls_to_be_inlined.")

# }}}

Expand Down
Loading

0 comments on commit d110d0f

Please sign in to comment.