cleanup OrderedSet
matthiasdiener committed Nov 10, 2023
1 parent 9af0d34 commit a476462
Showing 2 changed files with 32 additions and 39 deletions.
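This commit removes an ad-hoc "crude ordered set" aliasing block (which imported FrozenOrderedSet under the builtin name frozenset) in favor of direct imports from the orderedsets package, and declares that package as a dependency. A minimal sketch of the insertion-order behavior these types provide (an illustration assuming the orderedsets package's documented semantics, not code from this commit):

```python
from orderedsets import FrozenOrderedSet, OrderedSet

# Builtin sets iterate in a hash-dependent order that can vary between
# interpreter runs; OrderedSet/FrozenOrderedSet iterate in insertion order,
# which is what makes the computed partition deterministic across runs.
names = ["out_c", "out_a", "out_b"]

s = OrderedSet(names)
s.add("out_d")
assert list(s) == ["out_c", "out_a", "out_b", "out_d"]

# The frozen variant is immutable and supports the usual set algebra;
# union is assumed here to keep left-to-right insertion order.
f = FrozenOrderedSet(names) | FrozenOrderedSet(["out_a", "out_d"])
assert list(f) == ["out_c", "out_a", "out_b", "out_d"]
```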
70 changes: 31 additions & 39 deletions pytato/distributed/partition.py
@@ -85,6 +85,7 @@
         DistributedRecv, DistributedSend, DistributedSendRefHolder)
 from pytato.distributed.nodes import CommTagType
 from pytato.analysis import DirectPredecessorsGetter
+from orderedsets import OrderedSet, FrozenOrderedSet

 if TYPE_CHECKING:
     import mpi4py.MPI
@@ -118,15 +119,6 @@ class CommunicationOpIdentifier:
 _KeyT = TypeVar("_KeyT")
 _ValueT = TypeVar("_ValueT")

-
-# {{{ crude ordered set
-
-from orderedsets import OrderedSet as _OrderedSet
-from orderedsets import FrozenOrderedSet as frozenset
-
-# }}}
-
-
 # {{{ distributed graph part

 PartId = Hashable
@@ -239,9 +231,9 @@ def __init__(self,
         self.recvd_ary_to_name = recvd_ary_to_name
         self.sptpo_ary_to_name = sptpo_ary_to_name
         self.name_to_output = name_to_output
-        self.output_arrays = frozenset(name_to_output.values())
+        self.output_arrays = FrozenOrderedSet(name_to_output.values())

-        self.user_input_names: Set[str] = _OrderedSet()
+        self.user_input_names: Set[str] = OrderedSet()
         self.partition_input_name_to_placeholder: Dict[str, Placeholder] = {}

     def map_placeholder(self, expr: Placeholder) -> Placeholder:
@@ -353,11 +345,11 @@ def _make_distributed_partition(

         parts[part_id] = DistributedGraphPart(
                 pid=part_id,
-                needed_pids=frozenset({part_id - 1} if part_id else {}),
-                user_input_names=frozenset(comm_replacer.user_input_names),
-                partition_input_names=frozenset(
+                needed_pids=FrozenOrderedSet({part_id - 1} if part_id else {}),
+                user_input_names=FrozenOrderedSet(comm_replacer.user_input_names),
+                partition_input_names=FrozenOrderedSet(
                     comm_replacer.partition_input_name_to_placeholder.keys()),
-                output_names=frozenset(name_to_ouput.keys()),
+                output_names=FrozenOrderedSet(name_to_ouput.keys()),
                 name_to_recv_node=immutabledict({
                     recvd_ary_to_name[local_recv_id_to_recv_node[recv_id]]:
                         local_recv_id_to_recv_node[recv_id]
@@ -419,7 +411,7 @@ def __init__(self, local_rank: int) -> None:
     def combine(
             self, *args: FrozenSet[CommunicationOpIdentifier]
             ) -> FrozenSet[CommunicationOpIdentifier]:
-        return reduce(frozenset.union, args, frozenset())
+        return reduce(FrozenOrderedSet.union, args, FrozenOrderedSet())

     def map_distributed_send_ref_holder(self,
             expr: DistributedSendRefHolder
@@ -438,7 +430,7 @@ def map_distributed_send_ref_holder(self,
         return self.rec(expr.passthrough_data)

     def _map_input_base(self, expr: Array) -> FrozenSet[CommunicationOpIdentifier]:
-        return frozenset()
+        return FrozenOrderedSet()

     map_placeholder = _map_input_base
     map_data_wrapper = _map_input_base
@@ -453,11 +445,11 @@ def map_distributed_recv(
             from pytato.distributed.verify import DuplicateRecvError
             raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'")

-        self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset()
+        self.local_comm_ids_to_needed_comm_ids[recv_id] = FrozenOrderedSet()

         self.local_recv_id_to_recv_node[recv_id] = expr

-        return frozenset({recv_id})
+        return FrozenOrderedSet({recv_id})

 # }}}
@@ -493,7 +485,7 @@ def _schedule_task_batches_counted(
     task_to_dep_level, visits_in_depend = \
         _calculate_dependency_levels(task_ids_to_needed_task_ids)
     nlevels = 1 + max(task_to_dep_level.values(), default=-1)
-    task_batches: Sequence[Set[TaskType]] = [set() for _ in range(nlevels)]
+    task_batches: Sequence[Set[TaskType]] = [OrderedSet() for _ in range(nlevels)]

     for task_id, dep_level in task_to_dep_level.items():
         task_batches[dep_level].add(task_id)
@@ -518,7 +510,7 @@ def _calculate_dependency_levels(
        1 + the maximum dependency level for its children.
     """
     task_to_dep_level: Dict[TaskType, int] = {}
-    seen: set[TaskType] = set()
+    seen: set[TaskType] = OrderedSet()
     nodes_visited: int = 0

     def _dependency_level_dfs(task_id: TaskType) -> int:
@@ -559,7 +551,7 @@ class _MaterializedArrayCollector(CachedWalkMapper):
     """
     def __init__(self) -> None:
         super().__init__()
-        self.materialized_arrays: _OrderedSet[Array] = _OrderedSet()
+        self.materialized_arrays: OrderedSet[Array] = OrderedSet()

     # type-ignore-reason: dropped the extra `*args, **kwargs`.
     def get_cache_key(self, expr: ArrayOrNames) -> int:  # type: ignore[override]
@@ -595,7 +587,7 @@ def _set_dict_union_mpi(
     assert mpi_data_type is None
     result = dict(dict_a)
     for key, values in dict_b.items():
-        result[key] = result.get(key, frozenset()) | values
+        result[key] = result.get(key, FrozenOrderedSet()) | values
     return result

 # }}}
@@ -773,9 +765,9 @@ def find_distributed_partition(

     part_comm_ids: List[_PartCommIDs] = []
     if comm_batches:
-        recv_ids: FrozenSet[CommunicationOpIdentifier] = frozenset()
+        recv_ids: FrozenSet[CommunicationOpIdentifier] = FrozenOrderedSet()
         for batch in comm_batches:
-            send_ids = frozenset(
+            send_ids = FrozenOrderedSet(
                 comm_id for comm_id in batch
                 if comm_id.src_rank == local_rank)
             if recv_ids or send_ids:
@@ -784,19 +776,19 @@ def find_distributed_partition(
                     recv_ids=recv_ids,
                     send_ids=send_ids))
             # These go into the next part
-            recv_ids = frozenset(
+            recv_ids = FrozenOrderedSet(
                 comm_id for comm_id in batch
                 if comm_id.dest_rank == local_rank)
         if recv_ids:
             part_comm_ids.append(
                 _PartCommIDs(
                     recv_ids=recv_ids,
-                    send_ids=frozenset()))
+                    send_ids=FrozenOrderedSet()))
     else:
         part_comm_ids.append(
             _PartCommIDs(
-                recv_ids=frozenset(),
-                send_ids=frozenset()))
+                recv_ids=FrozenOrderedSet(),
+                send_ids=FrozenOrderedSet()))

     nparts = len(part_comm_ids)
@@ -826,10 +818,10 @@ def find_distributed_partition(
     # The sets of arrays below must have a deterministic order in order to ensure
     # that the resulting partition is also deterministic

-    sent_arrays = _OrderedSet(
+    sent_arrays = FrozenOrderedSet(
         send_node.data for _, send_node in sorted(lsrdg.local_send_id_to_send_node.items()))

-    received_arrays = _OrderedSet([recv for _, recv in sorted(lsrdg.local_recv_id_to_recv_node.items())])
+    received_arrays = FrozenOrderedSet([recv for _, recv in sorted(lsrdg.local_recv_id_to_recv_node.items())])

     # While receive nodes may be marked as materialized, we shouldn't be
     # including them here because we're using them (along with the send nodes)
@@ -838,17 +830,17 @@ def find_distributed_partition(
     # from send *nodes*, but we choose to exclude them in order to simplify the
     # processing below.
     materialized_arrays = (
-        materialized_arrays_collector.materialized_arrays
+        FrozenOrderedSet(materialized_arrays_collector.materialized_arrays)
         - received_arrays
         - sent_arrays)

     # "mso" for "materialized/sent/output"
-    output_arrays = _OrderedSet(outputs._data.values())
+    output_arrays = FrozenOrderedSet(outputs._data.values())
     mso_arrays = materialized_arrays | sent_arrays | output_arrays

     # FIXME: This gathers up materialized_arrays recursively, leading to
     # result sizes potentially quadratic in the number of materialized arrays.
-    mso_array_dep_mapper = SubsetDependencyMapper(frozenset(mso_arrays))
+    mso_array_dep_mapper = SubsetDependencyMapper(FrozenOrderedSet(mso_arrays))

     mso_ary_to_first_dep_send_part_id: Dict[Array, int] = {
         ary: nparts
@@ -909,29 +901,29 @@ def find_distributed_partition(
     assert all(0 <= part_id < nparts
                for part_id in stored_ary_to_part_id.values())

-    stored_arrays = _OrderedSet(stored_ary_to_part_id)
+    stored_arrays = FrozenOrderedSet(stored_ary_to_part_id)

     # {{{ find which stored arrays should become part outputs
     # (because they are used in not just their local part, but also others)

     direct_preds_getter = DirectPredecessorsGetter()

-    def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]:
-        materialized_preds: _OrderedSet[Array] = _OrderedSet()
+    def get_materialized_predecessors(ary: Array) -> OrderedSet[Array]:
+        materialized_preds: OrderedSet[Array] = OrderedSet()
         for pred in direct_preds_getter(ary):
             if pred in materialized_arrays:
                 materialized_preds.add(pred)
             else:
                 materialized_preds |= get_materialized_predecessors(pred)
         return materialized_preds

-    stored_arrays_promoted_to_part_outputs = {
+    stored_arrays_promoted_to_part_outputs = FrozenOrderedSet([
         stored_pred
         for stored_ary in stored_arrays
         for stored_pred in get_materialized_predecessors(stored_ary)
         if (stored_ary_to_part_id[stored_ary]
             != stored_ary_to_part_id[stored_pred])
-    }
+    ])

     # }}}
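Note that the block removed above had aliased FrozenOrderedSet to the builtin name frozenset, so every builtin-looking frozenset(...) call in this file was silently constructing an ordered set; the commit spells the names out instead. A short sketch of that hazard and of the explicit reduce/union pattern now used in combine() (illustrative only, assuming the orderedsets API shown in the diff):

```python
from functools import reduce

# Before the cleanup: shadowing the builtin hides which set type a call
# site produces (this is the import the commit removes).
from orderedsets import FrozenOrderedSet as frozenset  # shadows the builtin

ids = frozenset(["recv0", "send0"])
print(type(ids).__name__)  # FrozenOrderedSet, despite the builtin-looking name

# After the cleanup: the explicit name makes the ordered semantics visible,
# as in the combine() method above, which unions per-argument ID sets.
from orderedsets import FrozenOrderedSet

batches = [FrozenOrderedSet(["a", "b"]), FrozenOrderedSet(["b", "c"])]
combined = reduce(FrozenOrderedSet.union, batches, FrozenOrderedSet())
assert list(combined) == ["a", "b", "c"]  # assumes order-preserving union
```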
1 change: 1 addition & 0 deletions setup.py
@@ -39,6 +39,7 @@
         "immutabledict",
         "attrs",
         "bidict",
+        "orderedsets",
     ],
     package_data={"pytato": ["py.typed"]},
     author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei",
