From e83b4a63150fe50cc24298b063151b9004df8e57 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 23 Jan 2025 17:05:16 -0600 Subject: [PATCH 1/2] Design: clarify that reserved namespaces are not mandatory --- doc/design.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/design.rst b/doc/design.rst index d7791cab7..af53109a0 100644 --- a/doc/design.rst +++ b/doc/design.rst @@ -169,6 +169,13 @@ Reserved Identifiers as automatically generated names (if required) in :attr:`pytato.array.IndexLambda.bindings`. + +.. note:: + + Other than the iname names (``_[0-9]+``), these naming conventions are not + compulsory. The above is merely intended to set aside parts of the namespace + for this purpose that are guaranteed not to be trampled on by the user. + Tags ---- From 91be5939eca326306c4efa6a89e9fa254b65b04f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 23 Jan 2025 17:06:55 -0600 Subject: [PATCH 2/2] Enable Ruff SIM rules --- doc/conf.py | 5 +++-- pyproject.toml | 1 + pytato/analysis/__init__.py | 13 ++++++------- pytato/array.py | 25 +++++++++---------------- pytato/codegen.py | 29 +++++++++++++---------------- pytato/scalar_expr.py | 2 +- pytato/target/loopy/__init__.py | 4 ++-- pytato/target/loopy/codegen.py | 5 +---- pytato/utils.py | 27 +++++++-------------------- pytato/visualization/dot.py | 12 ++++-------- test/test_apps.py | 5 ++--- test/test_codegen.py | 6 ++---- test/test_distributed.py | 4 ++-- 13 files changed, 53 insertions(+), 85 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 19cfe8581..e9a2710f4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -10,8 +10,9 @@ author = "Pytato Contributors" ver_dic = {} -exec(compile(open("../pytato/version.py").read(), "../pytato/version.py", - "exec"), ver_dic) +with open("../pytato/version.py") as vfile: + exec(compile(vfile.read(), "../pytato/version.py", "exec"), ver_dic) + version = ".".join(str(x) for x in ver_dic["VERSION"]) release = ver_dic["VERSION_TEXT"] diff --git a/pyproject.toml b/pyproject.toml index 38019888c..099041a4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ extend-select = [ "RUF", "UP", "TC", + "SIM", ] extend-ignore = [ "E226", diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index ce21160dd..2d91e51eb 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -249,13 +249,12 @@ def _get_indices_from_input_subscript(subscript: str, # }}} - if is_output: - if len(normalized_indices) != len(set(normalized_indices)): - repeated_idx = next(idx - for idx in normalized_indices - if normalized_indices.count(idx) > 1) - raise ValueError(f"Output subscript '{subscript}' contains " - f"'{repeated_idx}' multiple times.") + if is_output and len(normalized_indices) != len(set(normalized_indices)): + repeated_idx = next(idx + for idx in normalized_indices + if normalized_indices.count(idx) > 1) + raise ValueError(f"Output subscript '{subscript}' contains " + f"'{repeated_idx}' multiple times.") return tuple(normalized_indices) diff --git a/pytato/array.py b/pytato/array.py index e42706613..ca44c2c2b 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -355,10 +355,7 @@ def _augment_array_dataclass( attr_tuple_hash = ", ".join(f"self.{fld.name}" for fld in fields(cls) if fld.name != "non_equality_tags") - if attr_tuple_hash: - attr_tuple_hash = f"({attr_tuple_hash},)" - else: - attr_tuple_hash = "()" + attr_tuple_hash = f"({attr_tuple_hash},)" if attr_tuple_hash else "()" from pytools.codegen import remove_common_indentation augment_code = remove_common_indentation( @@ -414,11 +411,10 @@ def _dataclass_setstate(self, state): # place, or it inherits a value but does not set it itself. sets_mapper_method = "_mapper_method" in mm_cls.__dict__ - if sets_mapper_method: - if default_mapper_method_name == mm_cls._mapper_method: - warn(f"Explicit _mapper_method on {mm_cls} not needed, default matches " - "explicit assignment. Just delete the explicit assignment.", - stacklevel=3) + if sets_mapper_method and default_mapper_method_name == mm_cls._mapper_method: + warn(f"Explicit _mapper_method on {mm_cls} not needed, default matches " + "explicit assignment. Just delete the explicit assignment.", + stacklevel=3) if not sets_mapper_method: mm_cls._mapper_method = intern(default_mapper_method_name) @@ -1500,9 +1496,9 @@ def einsum(subscripts: str, *operands: Array, raise ValueError(f"'{idx}' is not a reduction dim.") for descr in index_to_descr.values(): - if isinstance(descr, EinsumReductionAxis): - if descr not in redn_axis_to_redn_descr: - redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset()) + if (isinstance(descr, EinsumReductionAxis) + and descr not in redn_axis_to_redn_descr): + redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset()) # }}} @@ -2356,10 +2352,7 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN, if order != "C": raise ValueError("Only C-ordered arrays supported for now.") - if dtype is None: - conv_dtype = np.array(fill_value).dtype - else: - conv_dtype = np.dtype(dtype) + conv_dtype = np.array(fill_value).dtype if dtype is None else np.dtype(dtype) shape = normalize_shape(shape) diff --git a/pytato/codegen.py b/pytato/codegen.py index 85ac4052d..eae7ea286 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -79,10 +79,7 @@ def is_symbolic_index(o: object) -> TypeIs[SymbolicIndex]: if isinstance(o, tuple): - for i in o: - if not is_integral_scalar_expression(i): - return False - return True + return all(is_integral_scalar_expression(i) for i in o) else: return False @@ -280,18 +277,18 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) def post_visit(self, expr: Any) -> None: - if isinstance(expr, Placeholder | SizeParam | DataWrapper): - if expr.name is not None: - try: - ary = self.name_to_input[expr.name] - except KeyError: - self.name_to_input[expr.name] = expr - else: - if ary is not expr: - from pytato.diagnostic import NameClashError - raise NameClashError( - "Received two separate instances of inputs " - f"named '{expr.name}'.") + if (isinstance(expr, Placeholder | SizeParam | DataWrapper) + and expr.name is not None): + try: + ary = self.name_to_input[expr.name] + except KeyError: + self.name_to_input[expr.name] = expr + else: + if ary is not expr: + from pytato.diagnostic import NameClashError + raise NameClashError( + "Received two separate instances of inputs " + f"named '{expr.name}'.") def check_validity_of_outputs(exprs: DictOfNamedArrays) -> None: diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index e74007c0f..4d55b4985 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -92,7 +92,7 @@ def is_integral_scalar_expression(expr: object) -> TypeIs[IntegralScalarExpression]: - return isinstance(expr, int | np.integer) or isinstance(expr, prim.ExpressionNode) + return isinstance(expr, int | np.integer | prim.ExpressionNode) def parse(s: str) -> ScalarExpression: diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index c6568ba58..ea66eea14 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -220,7 +220,7 @@ def __call__(self, queue: pyopencl.CommandQueue, # type: ignore **kwargs: Any) -> Any: """Convenience function for launching a :mod:`pyopencl` computation.""" - if __debug__: + if __debug__: # noqa: SIM102 if set(kwargs.keys()) & set(self.bound_arguments.keys()): raise ValueError("Got arguments that were previously bound: " f"{set(kwargs.keys()) & set(self.bound_arguments.keys())}.") @@ -314,7 +314,7 @@ def __call__(self, queue: pyopencl.CommandQueue, # type: ignore **kwargs: Any) -> Any: """Convenience function for launching a :mod:`pyopencl` computation.""" - if __debug__: + if __debug__: # noqa: SIM102 if set(kwargs.keys()) & set(self.bound_arguments.keys()): raise ValueError("Got arguments that were previously bound: " f"{set(kwargs.keys()) & set(self.bound_arguments.keys())}.") diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index c4b029a33..c8e2e084f 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -904,10 +904,7 @@ def add_store(name: str, expr: Array, result: ImplementedResult, # Make the instruction from loopy.kernel.instruction import make_assignment - if indices: - assignee = prim.Variable(name)[indices] - else: - assignee = prim.Variable(name) + assignee = prim.Variable(name)[indices] if indices else prim.Variable(name) insn_id = state.insn_id_gen(f"{name}_store") insn = make_assignment((assignee,), loopy_expr, diff --git a/pytato/utils.py b/pytato/utils.py index 10a5b7d43..362cc4599 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -129,9 +129,8 @@ def _get_result_axis_length(axis_lengths: list[ShapeComponent] ) -> ShapeComponent: result_axis_len = axis_lengths[0] for axis_len in axis_lengths[1:]: - if are_shape_components_equal(axis_len, result_axis_len): - pass - elif are_shape_components_equal(axis_len, 1): + if (are_shape_components_equal(axis_len, result_axis_len) + or are_shape_components_equal(axis_len, 1)): pass elif are_shape_components_equal(result_axis_len, 1): result_axis_len = axis_len @@ -233,7 +232,7 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression | Bool ) -> ScalarExpression | Bool: - if ((isinstance(array, Array) or isinstance(array, np.generic)) + if ((isinstance(array, Array | np.generic)) and array.dtype != result_dtype): # Loopy's type casts don't like casting to bool assert result_dtype != np.bool_ @@ -460,15 +459,9 @@ def _normalize_slice(slice_: slice, if -axis_len <= start < axis_len: start = start % axis_len elif start >= axis_len: - if step > 0: - start = axis_len - else: - start = axis_len - 1 + start = axis_len if step > 0 else axis_len - 1 else: - if step > 0: - start = 0 - else: - start = -1 + start = 0 if step > 0 else -1 else: raise NotImplementedError @@ -479,15 +472,9 @@ def _normalize_slice(slice_: slice, if -axis_len <= stop < axis_len: stop = stop % axis_len elif stop >= axis_len: - if step > 0: - stop = axis_len - else: - stop = axis_len - 1 + stop = axis_len if step > 0 else axis_len - 1 else: - if step > 0: - stop = 0 - else: - stop = -1 + stop = 0 if step > 0 else -1 else: raise NotImplementedError diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 02352fb95..c0c3e7945 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -207,11 +207,7 @@ def handle_unsupported_array(self, expr: Array) -> None: continue attr = getattr(expr, field.name) - if isinstance(attr, Array): - self.rec(attr) - info.edges[field.name] = attr - - elif isinstance(attr, AbstractResultWithNamedArrays): + if isinstance(attr, Array | AbstractResultWithNamedArrays): self.rec(attr) info.edges[field.name] = attr @@ -632,7 +628,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: for part in partition.parts.values(): array_to_id = {} - for array in part_id_func_to_node_info[part.pid, None].keys(): + for array in part_id_func_to_node_info[part.pid, None]: if isinstance(array, Placeholder): # Placeholders are only emitted once if array in placeholder_to_id: @@ -704,7 +700,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: input_arrays: list[Array] = [] internal_arrays: list[ArrayOrNames] = [] - for array in part_node_to_info.keys(): + for array in part_node_to_info: if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -811,7 +807,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # {{{ draw overall outputs combined_array_to_id: dict[ArrayOrNames, str] = {} - for part_id in partition.parts.keys(): + for part_id in partition.parts: combined_array_to_id.update(part_id_to_array_to_id[part_id]) _emit_name_cluster( diff --git a/test/test_apps.py b/test/test_apps.py index afb3a9ae4..fe1ba18bb 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -110,9 +110,8 @@ def __init__(self, fft_vec_gatherer): def map_index_lambda(self, expr): tags = expr.tags_of_type(FFTIntermediate) - if tags: - if self.finalized or expr in self.old_array_to_new_array: - return self.old_array_to_new_array[expr] + if tags and (self.finalized or expr in self.old_array_to_new_array): + return self.old_array_to_new_array[expr] return super().map_index_lambda( expr.copy(expr=ConstantSizer()(expr.expr))) diff --git a/test/test_codegen.py b/test/test_codegen.py index 5b62ce468..a1ba89611 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -454,11 +454,9 @@ def test_slice(ctx_factory, shape): outputs = {} ref_outputs = {} - i = 0 - for slice_ in generate_test_slices(shape): + for i, slice_ in enumerate(generate_test_slices(shape)): outputs[f"out_{i}"] = x[slice_] ref_outputs[f"out_{i}"] = x_in[slice_] - i += 1 prog = pt.generate_loopy(outputs) @@ -1956,7 +1954,7 @@ def build_expression(tracer): assert len(outputs) == len(expected) - for key in outputs.keys(): + for key in outputs: np.testing.assert_allclose(outputs[key], expected[key]) diff --git a/test/test_distributed.py b/test/test_distributed.py index 1554a024b..d78479e08 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -135,7 +135,7 @@ def test_distributed_scheduler_counts(): count_list = np.zeros(len(sizes)) for i, tree_size in enumerate(sizes): needed_ids = {i: set() for i in range(int(tree_size))} - for key in needed_ids.keys(): + for key in needed_ids: needed_ids[key] = {key-1} if key > 0 else set() _, count_list[i] = _schedule_task_batches_counted(needed_ids) @@ -190,7 +190,7 @@ def test_distributed_scheduling_o_n_direct_dependents(): count_list = np.zeros(len(sizes)) for i, tree_size in enumerate(sizes): needed_ids = {i: set() for i in range(int(tree_size))} - for key in needed_ids.keys(): + for key in needed_ids: for j in range(key): needed_ids[key].add(j) _, count_list[i] = _schedule_task_batches_counted(needed_ids)