Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for functions #471

Merged
merged 14 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType)
from pytato.function import FunctionDefinition, Call
from pytato.function import FunctionDefinition, Call, NamedCallResult
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.loopy import LoopyCall
from pymbolic.mapper.optimize import optimize_mapper
Expand Down Expand Up @@ -174,6 +174,15 @@ def map_distributed_recv(self, expr: DistributedRecv) -> None:
self.nusers[dim] += 1
self.rec(dim)

def map_call(self, expr: Call) -> None:
for ary in expr.bindings.values():
if isinstance(ary, Array):
self.nusers[ary] += 1
self.rec(ary)

def map_named_call_result(self, expr: NamedCallResult) -> None:
self.rec(expr._container)

# }}}


Expand Down Expand Up @@ -358,6 +367,12 @@ def map_distributed_send_ref_holder(self,
) -> FrozenSet[Array]:
return frozenset([expr.passthrough_data])

def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]:
raise NotImplementedError(
"DirectPredecessorsGetter does not yet support expressions containing "
"functions.")


# }}}


Expand Down Expand Up @@ -423,10 +438,9 @@ def map_function_definition(self, /, expr: FunctionDefinition,
if not self.visit(expr):
return

new_mapper = self.clone_for_callee()
new_mapper = self.clone_for_callee(expr)
for subexpr in expr.returns.values():
new_mapper(subexpr, *args, **kwargs)

self.count += new_mapper.count

self.post_visit(expr, *args, **kwargs)
Expand Down
7 changes: 4 additions & 3 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DataInterface, SizeParam, InputArgumentBase,
make_dict_of_named_arrays)

from pytato.function import NamedCallResult
from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin

from pytato.scalar_expr import IntegralScalarExpression
Expand Down Expand Up @@ -112,9 +113,6 @@ def __init__(self, target: Target,
self.target = target
self.kernels_seen: Dict[str, lp.LoopKernel] = kernels_seen or {}

def clone_for_callee(self) -> CodeGenPreprocessor:
return CodeGenPreprocessor(self.target, self.kernels_seen)

def map_size_param(self, expr: SizeParam) -> Array:
name = expr.name
assert name is not None
Expand Down Expand Up @@ -196,6 +194,9 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
axes=expr.axes,
tags=expr.tags)

def map_named_call_result(self, expr: NamedCallResult) -> Array:
raise NotImplementedError("CodeGenPreprocessor does not support functions.")

# }}}


Expand Down
17 changes: 14 additions & 3 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
from pytato.distributed.nodes import CommTagType
from pytato.analysis import DirectPredecessorsGetter

from pytato.function import FunctionDefinition, NamedCallResult

if TYPE_CHECKING:
import mpi4py.MPI

Expand Down Expand Up @@ -291,6 +293,12 @@ def __init__(self,
self.user_input_names: Set[str] = set()
self.partition_input_name_to_placeholder: Dict[str, Placeholder] = {}

def clone_for_callee(
self, function: FunctionDefinition) -> _DistributedInputReplacer:
# Function definitions aren't allowed to contain receives,
# stored arrays promoted to part outputs, or part outputs
return type(self)({}, {}, {})

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
return expr
Expand Down Expand Up @@ -323,8 +331,6 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:

# type ignore because no args, kwargs
def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override]
assert isinstance(expr, Array)

key = self.get_cache_key(expr)
try:
return self._cache[key]
Expand All @@ -334,7 +340,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override]
# If the array is an output from the current part, it would
# be counterproductive to turn it into a placeholder: we're
# the ones who are supposed to compute it!
if expr not in self.output_arrays:
if isinstance(expr, Array) and expr not in self.output_arrays:

name = self.sptpo_ary_to_name.get(expr)
if name is not None:
Expand Down Expand Up @@ -502,6 +508,11 @@ def map_distributed_recv(

return frozenset({recv_id})

def map_named_call_result(
self, expr: NamedCallResult) -> FrozenSet[CommunicationOpIdentifier]:
raise NotImplementedError(
"LocalSendRecvDepGatherer does not support functions.")

# }}}


Expand Down
31 changes: 18 additions & 13 deletions pytato/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@
Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping)
from immutabledict import immutabledict
from functools import cached_property
from pytato.array import (Array, AbstractResultWithNamedArrays,
Placeholder, NamedArray, ShapeType, _dtype_any,
InputArgumentBase)
from pytato.array import (Array, AbstractResultWithNamedArrays,
Placeholder, NamedArray, ShapeType, _dtype_any)
from pytools.tag import Tag, Taggable

ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array])
Expand Down Expand Up @@ -132,19 +131,25 @@ class FunctionDefinition(Taggable):
@cached_property
def _placeholders(self) -> Mapping[str, Placeholder]:
from pytato.transform import InputGatherer
from functools import reduce

mapper = InputGatherer()

all_input_args: FrozenSet[InputArgumentBase] = reduce(
frozenset.union,
(mapper(ary) for ary in self.returns.values()),
frozenset()
)
all_placeholders: FrozenSet[Placeholder] = frozenset()
for ary in self.returns.values():
new_placeholders = frozenset({
arg for arg in mapper(ary)
if isinstance(arg, Placeholder)})
all_placeholders |= new_placeholders

return immutabledict({input_arg.name: input_arg
for input_arg in all_input_args
if isinstance(input_arg, Placeholder)})
# FIXME: Need a way to check for *any* captured arrays, not just placeholders
if __debug__:
pl_names = frozenset(arg.name for arg in all_placeholders)
extra_pl_names = pl_names - self.parameters
assert not extra_pl_names, \
f"Found non-argument placeholder '{next(iter(extra_pl_names))}' " \
"in function definition."

return immutabledict({arg.name: arg for arg in all_placeholders})

def get_placeholder(self, name: str) -> Placeholder:
"""
Expand All @@ -168,7 +173,7 @@ def __call__(self, **kwargs: Array

if self.parameters != frozenset(kwargs):
missing_params = self.parameters - frozenset(kwargs)
extra_params = self.parameters - frozenset(kwargs)
extra_params = frozenset(kwargs) - self.parameters

raise TypeError(
"Incorrect arguments."
Expand Down
35 changes: 35 additions & 0 deletions pytato/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@

import numpy as np

from pytools import memoize_method

from typing import Any, Dict, Tuple, cast
from pytato.transform import Mapper
from pytato.array import (Array, DataWrapper, DictOfNamedArrays, Axis,
IndexLambda, ReductionDescriptor)
from pytato.function import FunctionDefinition, Call
from pytato.loopy import LoopyCall
from immutabledict import immutabledict
import attrs
Expand Down Expand Up @@ -153,6 +156,38 @@ def _get_field_val(field: str) -> str:
for field in attrs.fields(type(expr)))
+ ")")

@memoize_method
def map_function_definition(self, expr: FunctionDefinition, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string

def _get_field_val(field: str) -> str:
if field == "returns":
return self.rec(getattr(expr, field), depth+1)
else:
return repr(getattr(expr, field))

return (f"{type(expr).__name__}("
+ ", ".join(f"{field.name}={_get_field_val(field.name)}"
for field in attrs.fields(type(expr)))
+ ")")

def map_call(self, expr: Call, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string

def _get_field_val(field: str) -> str:
if field == "function":
return self.map_function_definition(expr.function, depth+1)
else:
return self.rec(getattr(expr, field), depth+1)

return (f"{type(expr).__name__}("
+ ", ".join(f"{field}={_get_field_val(field)}"
for field in ["function",
"bindings"])
+ ")")

def map_loopy_call(self, expr: LoopyCall, depth: int) -> str:
if depth > self.truncation_depth:
return self.truncation_string
Expand Down
5 changes: 3 additions & 2 deletions pytato/target/loopy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ class CodeGenMapper(Mapper):
def __init__(self,
array_tag_t_to_not_propagate: FrozenSet[Type[Tag]],
axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]]) -> None:
super().__init__()
self.exprgen_mapper = InlinedExpressionGenMapper(axis_tag_t_to_not_propagate)
self.array_tag_t_to_not_propagate = array_tag_t_to_not_propagate
self.axis_tag_t_to_not_propagate = axis_tag_t_to_not_propagate
Expand Down Expand Up @@ -591,13 +592,13 @@ def map_named_call_result(self, expr: NamedCallResult,
raise NotImplementedError("LoopyTarget does not support outlined calls"
" (yet). As a fallback, the call"
" could be inlined using"
" pt.mark_all_calls_to_be_inlined.")
" pt.tag_all_calls_to_be_inlined.")

def map_call(self, expr: Call, state: CodeGenState) -> None:
raise NotImplementedError("LoopyTarget does not support outlined calls"
" (yet). As a fallback, the call"
" could be inlined using"
" pt.mark_all_calls_to_be_inlined.")
" pt.tag_all_calls_to_be_inlined.")

# }}}

Expand Down
Loading
Loading