From c09e9f17382a915a20a20a8d4357d2de7c6858fc Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 2 Jul 2024 07:57:09 -0500 Subject: [PATCH] move axis tag attaching code into a separate method in AxisTagAttacher --- pytato/transform/metadata.py | 103 ++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 49 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 49190c76e..635a478d8 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -601,56 +601,61 @@ def __init__(self, self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr + def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: + assert rec_expr.ndim == expr.ndim + + result = rec_expr + + for iaxis in range(expr.ndim): + result = result.with_tagged_axis( + iaxis, self.axis_to_tags.get((expr, iaxis), [])) + + # {{{ tag reduction descrs + + if self.tag_corresponding_redn_descr: + if isinstance(expr, Einsum): + assert isinstance(result, Einsum) + for arg, access_descrs in zip(expr.args, + expr.access_descriptors, + strict=True): + for iaxis, access_descr in enumerate(access_descrs): + if isinstance(access_descr, EinsumReductionAxis): + result = result.with_tagged_reduction( + access_descr, + self.axis_to_tags.get((arg, iaxis), []) + ) + + if isinstance(expr, IndexLambda): + assert isinstance(result, IndexLambda) + try: + hlo = index_lambda_to_high_level_op(expr) + except UnknownIndexLambdaExpr: + pass + else: + if isinstance(hlo, ReduceOp): + for iaxis, redn_var in hlo.axes.items(): + result = result.with_tagged_reduction( + redn_var, + self.axis_to_tags.get((hlo.x, iaxis), []) + ) + + # }}} + + return result + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - if isinstance(expr, AbstractResultWithNamedArrays | DistributedSendRefHolder): - return super().rec(expr) - else: - assert isinstance(expr, Array) - key = self.get_cache_key(expr) - try: - return self._cache[key] - except KeyError: - expr_copy = Mapper.rec(self, expr) - assert isinstance(expr_copy, Array) - assert expr_copy.ndim == expr.ndim - - for iaxis in range(expr.ndim): - expr_copy = expr_copy.with_tagged_axis( - iaxis, self.axis_to_tags.get((expr, iaxis), [])) - - # {{{ tag reduction descrs - - if self.tag_corresponding_redn_descr: - if isinstance(expr, Einsum): - assert isinstance(expr_copy, Einsum) - for arg, access_descrs in zip(expr.args, - expr.access_descriptors, - strict=True): - for iaxis, access_descr in enumerate(access_descrs): - if isinstance(access_descr, EinsumReductionAxis): - expr_copy = expr_copy.with_tagged_reduction( - access_descr, - self.axis_to_tags.get((arg, iaxis), []) - ) - - if isinstance(expr, IndexLambda): - assert isinstance(expr_copy, IndexLambda) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - pass - else: - if isinstance(hlo, ReduceOp): - for iaxis, redn_var in hlo.axes.items(): - expr_copy = expr_copy.with_tagged_reduction( - redn_var, - self.axis_to_tags.get((hlo.x, iaxis), []) - ) - - # }}} - - self._cache[key] = expr_copy - return expr_copy + key = self.get_cache_key(expr) + try: + return self._cache[key] + except KeyError: + result = Mapper.rec(self, expr) + if not isinstance( + expr, AbstractResultWithNamedArrays | DistributedSendRefHolder): + assert isinstance(expr, Array) + # type-ignore reason: passed "ArrayOrNames"; expected "Array" + result = self._attach_tags(expr, result) # type: ignore[arg-type] + self._cache[key] = result + return result def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError(