Skip to content

Commit

Permalink
Deterministic find_distributed_partition (orderedsets) (#572)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Jan 22, 2025
1 parent 46431ee commit 57ab576
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 127 deletions.
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"mpi4py": ("https://mpi4py.readthedocs.io/en/latest", None),
"immutabledict": ("https://immutabledict.corenting.fr/", None),
"orderedsets": ("https://matthiasdiener.github.io/orderedsets", None),
}

# Some modules need to import things just so that sphinx can resolve symbols in
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"pytools>=2024.1.21",
"pymbolic>=2024.2",
"typing_extensions>=4",
"orderedsets",
]

[project.urls]
Expand Down
54 changes: 29 additions & 25 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
THE SOFTWARE.
"""

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Never

from orderedsets import FrozenOrderedSet
from typing_extensions import Self

from loopy.tools import LoopyKeyBuilder
Expand Down Expand Up @@ -329,37 +331,37 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
We only consider the predecessors of a nodes in a data-flow sense.
"""
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]:
return frozenset({dim for dim in shape if isinstance(dim, Array)})
def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array))

def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]:
return (frozenset(expr.bindings.values())
def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]:
return (FrozenOrderedSet(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))

def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
def map_stack(self, expr: Stack) -> FrozenOrderedSet[ArrayOrNames]:
return (FrozenOrderedSet(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
def map_concatenate(self, expr: Concatenate) -> FrozenOrderedSet[ArrayOrNames]:
return (FrozenOrderedSet(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]:
return (frozenset(expr.args)
def map_einsum(self, expr: Einsum) -> FrozenOrderedSet[ArrayOrNames]:
return (FrozenOrderedSet(expr.args)
| self._get_preds_from_shape(expr.shape))

def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]:
def map_loopy_call_result(self, expr: NamedArray) -> FrozenOrderedSet[ArrayOrNames]:
from pytato.loopy import LoopyCall, LoopyCallResult
assert isinstance(expr, LoopyCallResult)
assert isinstance(expr._container, LoopyCall)
return (frozenset(ary
return (FrozenOrderedSet(ary
for ary in expr._container.bindings.values()
if isinstance(ary, Array))
| self._get_preds_from_shape(expr.shape))

def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
return (frozenset([expr.array])
| frozenset(idx for idx in expr.indices
def _map_index_base(self, expr: IndexBase) -> FrozenOrderedSet[ArrayOrNames]:
return (FrozenOrderedSet([expr.array])
| FrozenOrderedSet(idx for idx in expr.indices
if isinstance(idx, Array))
| self._get_preds_from_shape(expr.shape))

Expand All @@ -368,34 +370,36 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase
) -> frozenset[ArrayOrNames]:
return frozenset([expr.array])
) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet([expr.array])

map_roll = _map_index_remapping_base
map_axis_permutation = _map_index_remapping_base
map_reshape = _map_index_remapping_base

def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]:
def _map_input_base(self, expr: InputArgumentBase) \
-> FrozenOrderedSet[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

map_placeholder = _map_input_base
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]:
def map_distributed_recv(self,
expr: DistributedRecv) -> FrozenOrderedSet[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

def map_distributed_send_ref_holder(self,
expr: DistributedSendRefHolder
) -> frozenset[ArrayOrNames]:
return frozenset([expr.passthrough_data])
) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet([expr.passthrough_data])

def map_call(self, expr: Call) -> frozenset[ArrayOrNames]:
return frozenset(expr.bindings.values())
def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(expr.bindings.values())

def map_named_call_result(
self, expr: NamedCallResult) -> frozenset[ArrayOrNames]:
return frozenset([expr._container])
self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet([expr._container])


# }}}
Expand Down
Loading

0 comments on commit 57ab576

Please sign in to comment.