Skip to content

Commit

Permalink
move axis tag attaching code into a separate method in AxisTagAttacher
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jan 23, 2025
1 parent 1374fdb commit c09e9f1
Showing 1 changed file with 54 additions and 49 deletions.
103 changes: 54 additions & 49 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,56 +601,61 @@ def __init__(self,
self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags
self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr

def _attach_tags(self, expr: Array, rec_expr: Array) -> Array:
assert rec_expr.ndim == expr.ndim

result = rec_expr

for iaxis in range(expr.ndim):
result = result.with_tagged_axis(
iaxis, self.axis_to_tags.get((expr, iaxis), []))

# {{{ tag reduction descrs

if self.tag_corresponding_redn_descr:
if isinstance(expr, Einsum):
assert isinstance(result, Einsum)
for arg, access_descrs in zip(expr.args,
expr.access_descriptors,
strict=True):
for iaxis, access_descr in enumerate(access_descrs):
if isinstance(access_descr, EinsumReductionAxis):
result = result.with_tagged_reduction(
access_descr,
self.axis_to_tags.get((arg, iaxis), [])
)

if isinstance(expr, IndexLambda):
assert isinstance(result, IndexLambda)
try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
pass
else:
if isinstance(hlo, ReduceOp):
for iaxis, redn_var in hlo.axes.items():
result = result.with_tagged_reduction(
redn_var,
self.axis_to_tags.get((hlo.x, iaxis), [])
)

# }}}

return result

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
if isinstance(expr, AbstractResultWithNamedArrays | DistributedSendRefHolder):
return super().rec(expr)
else:
assert isinstance(expr, Array)
key = self.get_cache_key(expr)
try:
return self._cache[key]
except KeyError:
expr_copy = Mapper.rec(self, expr)
assert isinstance(expr_copy, Array)
assert expr_copy.ndim == expr.ndim

for iaxis in range(expr.ndim):
expr_copy = expr_copy.with_tagged_axis(
iaxis, self.axis_to_tags.get((expr, iaxis), []))

# {{{ tag reduction descrs

if self.tag_corresponding_redn_descr:
if isinstance(expr, Einsum):
assert isinstance(expr_copy, Einsum)
for arg, access_descrs in zip(expr.args,
expr.access_descriptors,
strict=True):
for iaxis, access_descr in enumerate(access_descrs):
if isinstance(access_descr, EinsumReductionAxis):
expr_copy = expr_copy.with_tagged_reduction(
access_descr,
self.axis_to_tags.get((arg, iaxis), [])
)

if isinstance(expr, IndexLambda):
assert isinstance(expr_copy, IndexLambda)
try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
pass
else:
if isinstance(hlo, ReduceOp):
for iaxis, redn_var in hlo.axes.items():
expr_copy = expr_copy.with_tagged_reduction(
redn_var,
self.axis_to_tags.get((hlo.x, iaxis), [])
)

# }}}

self._cache[key] = expr_copy
return expr_copy
key = self.get_cache_key(expr)
try:
return self._cache[key]
except KeyError:
result = Mapper.rec(self, expr)
if not isinstance(
expr, AbstractResultWithNamedArrays | DistributedSendRefHolder):
assert isinstance(expr, Array)
# type-ignore reason: passed "ArrayOrNames"; expected "Array"
result = self._attach_tags(expr, result) # type: ignore[arg-type]
self._cache[key] = result
return result

def map_named_call_result(self, expr: NamedCallResult) -> Array:
raise NotImplementedError(
Expand Down

0 comments on commit c09e9f1

Please sign in to comment.