Enable Ruff SIM rules

inducer committed Jan 23, 2025
1 parent e83b4a6 commit 91be593
Showing 13 changed files with 53 additions and 85 deletions.
5 changes: 3 additions & 2 deletions doc/conf.py
@@ -10,8 +10,9 @@
 author = "Pytato Contributors"

 ver_dic = {}
-exec(compile(open("../pytato/version.py").read(), "../pytato/version.py",
-        "exec"), ver_dic)
+with open("../pytato/version.py") as vfile:
+    exec(compile(vfile.read(), "../pytato/version.py", "exec"), ver_dic)

 version = ".".join(str(x) for x in ver_dic["VERSION"])
 release = ver_dic["VERSION_TEXT"]
1 change: 1 addition & 0 deletions pyproject.toml
@@ -64,6 +64,7 @@ extend-select = [
     "RUF",
     "UP",
     "TC",
+    "SIM",
 ]
 extend-ignore = [
     "E226",
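
Note: "SIM" selects Ruff's flake8-simplify rule family, which flags needlessly verbose control flow; the remaining hunks in this commit are its findings. Assuming Ruff is installed, the findings can be reproduced (and mostly auto-fixed) from the repository root with:

    ruff check .        # report SIM findings along with the other selected rules
    ruff check --fix .  # apply Ruff's safe auto-fixes
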
13 changes: 6 additions & 7 deletions pytato/analysis/__init__.py
@@ -249,13 +249,12 @@ def _get_indices_from_input_subscript(subscript: str,

     # }}}

-    if is_output:
-        if len(normalized_indices) != len(set(normalized_indices)):
-            repeated_idx = next(idx
-                                for idx in normalized_indices
-                                if normalized_indices.count(idx) > 1)
-            raise ValueError(f"Output subscript '{subscript}' contains "
-                             f"'{repeated_idx}' multiple times.")
+    if is_output and len(normalized_indices) != len(set(normalized_indices)):
+        repeated_idx = next(idx
+                            for idx in normalized_indices
+                            if normalized_indices.count(idx) > 1)
+        raise ValueError(f"Output subscript '{subscript}' contains "
+                         f"'{repeated_idx}' multiple times.")

     return tuple(normalized_indices)
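
Note: this rewrite is Ruff's SIM102 (collapsible-if): an outer if whose body is nothing but another if merges into a single conjunction. A standalone sketch with hypothetical names, not code from this commit:

    def check_unique(indices: list[str], is_output: bool) -> None:
        # Formerly a nested pair of ifs (SIM102). Since `and` short-circuits,
        # the duplicate scan still only runs when is_output is true.
        if is_output and len(indices) != len(set(indices)):
            raise ValueError("repeated index in output subscript")

    check_unique(["i", "j"], is_output=True)  # passes silently
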
25 changes: 9 additions & 16 deletions pytato/array.py
@@ -355,10 +355,7 @@ def _augment_array_dataclass(
     attr_tuple_hash = ", ".join(f"self.{fld.name}"
             for fld in fields(cls) if fld.name != "non_equality_tags")

-    if attr_tuple_hash:
-        attr_tuple_hash = f"({attr_tuple_hash},)"
-    else:
-        attr_tuple_hash = "()"
+    attr_tuple_hash = f"({attr_tuple_hash},)" if attr_tuple_hash else "()"

     from pytools.codegen import remove_common_indentation
     augment_code = remove_common_indentation(
@@ -414,11 +411,10 @@ def _dataclass_setstate(self, state):
     # place, or it inherits a value but does not set it itself.
     sets_mapper_method = "_mapper_method" in mm_cls.__dict__

-    if sets_mapper_method:
-        if default_mapper_method_name == mm_cls._mapper_method:
-            warn(f"Explicit _mapper_method on {mm_cls} not needed, default matches "
-                "explicit assignment. Just delete the explicit assignment.",
-                stacklevel=3)
+    if sets_mapper_method and default_mapper_method_name == mm_cls._mapper_method:
+        warn(f"Explicit _mapper_method on {mm_cls} not needed, default matches "
+            "explicit assignment. Just delete the explicit assignment.",
+            stacklevel=3)

     if not sets_mapper_method:
         mm_cls._mapper_method = intern(default_mapper_method_name)
@@ -1500,9 +1496,9 @@ def einsum(subscripts: str, *operands: Array,
             raise ValueError(f"'{idx}' is not a reduction dim.")

     for descr in index_to_descr.values():
-        if isinstance(descr, EinsumReductionAxis):
-            if descr not in redn_axis_to_redn_descr:
-                redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset())
+        if (isinstance(descr, EinsumReductionAxis)
+                and descr not in redn_axis_to_redn_descr):
+            redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset())

     # }}}

@@ -2356,10 +2352,7 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN,
     if order != "C":
         raise ValueError("Only C-ordered arrays supported for now.")

-    if dtype is None:
-        conv_dtype = np.array(fill_value).dtype
-    else:
-        conv_dtype = np.dtype(dtype)
+    conv_dtype = np.array(fill_value).dtype if dtype is None else np.dtype(dtype)

     shape = normalize_shape(shape)

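
Note: three of the hunks above are Ruff's SIM108 (if-else-block-instead-of-if-exp): a four-line if/else whose only purpose is to assign one variable becomes a conditional expression. A self-contained sketch with hypothetical names:

    import numpy as np

    def pick_dtype(fill_value, dtype=None):
        # Before (SIM108):
        #     if dtype is None:
        #         conv_dtype = np.array(fill_value).dtype
        #     else:
        #         conv_dtype = np.dtype(dtype)
        # After: one assignment, identical behavior.
        return np.array(fill_value).dtype if dtype is None else np.dtype(dtype)

    assert pick_dtype(1.0) == np.float64
    assert pick_dtype(1.0, "int32") == np.int32
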
29 changes: 13 additions & 16 deletions pytato/codegen.py
@@ -79,10 +79,7 @@

 def is_symbolic_index(o: object) -> TypeIs[SymbolicIndex]:
     if isinstance(o, tuple):
-        for i in o:
-            if not is_integral_scalar_expression(i):
-                return False
-        return True
+        return all(is_integral_scalar_expression(i) for i in o)
     else:
         return False

@@ -280,18 +277,18 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
         return id(expr)

     def post_visit(self, expr: Any) -> None:
-        if isinstance(expr, Placeholder | SizeParam | DataWrapper):
-            if expr.name is not None:
-                try:
-                    ary = self.name_to_input[expr.name]
-                except KeyError:
-                    self.name_to_input[expr.name] = expr
-                else:
-                    if ary is not expr:
-                        from pytato.diagnostic import NameClashError
-                        raise NameClashError(
-                                "Received two separate instances of inputs "
-                                f"named '{expr.name}'.")
+        if (isinstance(expr, Placeholder | SizeParam | DataWrapper)
+                and expr.name is not None):
+            try:
+                ary = self.name_to_input[expr.name]
+            except KeyError:
+                self.name_to_input[expr.name] = expr
+            else:
+                if ary is not expr:
+                    from pytato.diagnostic import NameClashError
+                    raise NameClashError(
+                            "Received two separate instances of inputs "
+                            f"named '{expr.name}'.")


 def check_validity_of_outputs(exprs: DictOfNamedArrays) -> None:
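
Note: the is_symbolic_index change is Ruff's SIM110 (reimplemented-builtin): a loop that returns False on the first failing element and True otherwise is exactly all(), which short-circuits the same way; the post_visit hunk is another SIM102 collapse. A sketch of the SIM110 shape with hypothetical names:

    def all_positive(values: list[int]) -> bool:
        # Before (SIM110):
        #     for v in values:
        #         if not v > 0:
        #             return False
        #     return True
        return all(v > 0 for v in values)

    assert all_positive([1, 2, 3])
    assert not all_positive([1, -2, 3])
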
2 changes: 1 addition & 1 deletion pytato/scalar_expr.py
@@ -92,7 +92,7 @@


 def is_integral_scalar_expression(expr: object) -> TypeIs[IntegralScalarExpression]:
-    return isinstance(expr, int | np.integer) or isinstance(expr, prim.ExpressionNode)
+    return isinstance(expr, int | np.integer | prim.ExpressionNode)


 def parse(s: str) -> ScalarExpression:
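
Note: this one-liner is Ruff's SIM101 (duplicate-isinstance-call): two isinstance checks on the same object joined by `or` collapse into a single call on a PEP 604 union (Python 3.10+, which the codebase's existing `int | np.integer` spelling already assumes). A sketch:

    def is_number_like(x: object) -> bool:
        # Before (SIM101): isinstance(x, int) or isinstance(x, float)
        # After: one call; union members are tested left to right, like `or`.
        return isinstance(x, int | float)

    assert is_number_like(3) and is_number_like(2.5)
    assert not is_number_like("3")
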
4 changes: 2 additions & 2 deletions pytato/target/loopy/__init__.py
@@ -220,7 +220,7 @@ def __call__(self, queue: pyopencl.CommandQueue,  # type: ignore
                  **kwargs: Any) -> Any:
         """Convenience function for launching a :mod:`pyopencl` computation."""

-        if __debug__:
+        if __debug__:  # noqa: SIM102
             if set(kwargs.keys()) & set(self.bound_arguments.keys()):
                 raise ValueError("Got arguments that were previously bound: "
                         f"{set(kwargs.keys()) & set(self.bound_arguments.keys())}.")
@@ -314,7 +314,7 @@ def __call__(self, queue: pyopencl.CommandQueue,  # type: ignore
                  **kwargs: Any) -> Any:
         """Convenience function for launching a :mod:`pyopencl` computation."""

-        if __debug__:
+        if __debug__:  # noqa: SIM102
             if set(kwargs.keys()) & set(self.bound_arguments.keys()):
                 raise ValueError("Got arguments that were previously bound: "
                         f"{set(kwargs.keys()) & set(self.bound_arguments.keys())}.")
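
Note: here the nested if is deliberately kept and SIM102 is silenced with a targeted noqa. A plausible rationale (not stated in the commit): CPython removes a bare `if __debug__:` block entirely under `python -O`, whereas a merged compound condition is not guaranteed to be optimized away, and the separate guard keeps the debug-only intent obvious. A sketch of the pattern with hypothetical names:

    def launch(bound_arguments: dict, kwargs: dict) -> None:
        if __debug__:  # noqa: SIM102
            # Debug-only validation; the whole block vanishes under -O.
            if set(kwargs) & set(bound_arguments):
                raise ValueError("got arguments that were previously bound")

    launch({"a": 1}, {"b": 2})  # fine; overlapping keys would raise
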
5 changes: 1 addition & 4 deletions pytato/target/loopy/codegen.py
@@ -904,10 +904,7 @@ def add_store(name: str, expr: Array, result: ImplementedResult,

     # Make the instruction
     from loopy.kernel.instruction import make_assignment
-    if indices:
-        assignee = prim.Variable(name)[indices]
-    else:
-        assignee = prim.Variable(name)
+    assignee = prim.Variable(name)[indices] if indices else prim.Variable(name)
     insn_id = state.insn_id_gen(f"{name}_store")
     insn = make_assignment((assignee,),
             loopy_expr,
27 changes: 7 additions & 20 deletions pytato/utils.py
@@ -129,9 +129,8 @@ def _get_result_axis_length(axis_lengths: list[ShapeComponent]
                             ) -> ShapeComponent:
     result_axis_len = axis_lengths[0]
     for axis_len in axis_lengths[1:]:
-        if are_shape_components_equal(axis_len, result_axis_len):
-            pass
-        elif are_shape_components_equal(axis_len, 1):
+        if (are_shape_components_equal(axis_len, result_axis_len)
+                or are_shape_components_equal(axis_len, 1)):
             pass
         elif are_shape_components_equal(result_axis_len, 1):
             result_axis_len = axis_len
@@ -233,7 +232,7 @@ def cast_to_result_type(
         array: ArrayOrScalar,
         expr: ScalarExpression | Bool
         ) -> ScalarExpression | Bool:
-    if ((isinstance(array, Array) or isinstance(array, np.generic))
+    if ((isinstance(array, Array | np.generic))
             and array.dtype != result_dtype):
         # Loopy's type casts don't like casting to bool
         assert result_dtype != np.bool_
@@ -460,15 +459,9 @@ def _normalize_slice(slice_: slice,
             if -axis_len <= start < axis_len:
                 start = start % axis_len
             elif start >= axis_len:
-                if step > 0:
-                    start = axis_len
-                else:
-                    start = axis_len - 1
+                start = axis_len if step > 0 else axis_len - 1
             else:
-                if step > 0:
-                    start = 0
-                else:
-                    start = -1
+                start = 0 if step > 0 else -1
         else:
             raise NotImplementedError

@@ -479,15 +472,9 @@ def _normalize_slice(slice_: slice,
             if -axis_len <= stop < axis_len:
                 stop = stop % axis_len
             elif stop >= axis_len:
-                if step > 0:
-                    stop = axis_len
-                else:
-                    stop = axis_len - 1
+                stop = axis_len if step > 0 else axis_len - 1
             else:
-                if step > 0:
-                    stop = 0
-                else:
-                    stop = -1
+                stop = 0 if step > 0 else -1
         else:
             raise NotImplementedError

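
Note: the _get_result_axis_length hunk is Ruff's SIM114 (if-with-same-arms): adjacent if/elif branches with identical bodies merge into one condition joined by `or`; the slice-normalization hunks are further SIM108 ternary rewrites. A simplified sketch of the SIM114 shape, using plain ints in place of symbolic shape components:

    def fold_axis_lengths(axis_lengths: list[int]) -> int:
        # Broadcasting-style fold: equal lengths, or a length of 1, are compatible.
        result = axis_lengths[0]
        for n in axis_lengths[1:]:
            # SIM114: the first two branches had identical bodies and were
            # merged into a single `or` condition.
            if n == result or n == 1:
                pass
            elif result == 1:
                result = n
            else:
                raise ValueError(f"lengths {n} and {result} do not broadcast")
        return result

    assert fold_axis_lengths([1, 3, 3]) == 3
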
12 changes: 4 additions & 8 deletions pytato/visualization/dot.py
@@ -207,11 +207,7 @@ def handle_unsupported_array(self, expr: Array) -> None:
                 continue
             attr = getattr(expr, field.name)

-            if isinstance(attr, Array):
-                self.rec(attr)
-                info.edges[field.name] = attr
-
-            elif isinstance(attr, AbstractResultWithNamedArrays):
+            if isinstance(attr, Array | AbstractResultWithNamedArrays):
                 self.rec(attr)
                 info.edges[field.name] = attr


for part in partition.parts.values():
array_to_id = {}
for array in part_id_func_to_node_info[part.pid, None].keys():
for array in part_id_func_to_node_info[part.pid, None]:
if isinstance(array, Placeholder):
# Placeholders are only emitted once
if array in placeholder_to_id:
Expand Down Expand Up @@ -704,7 +700,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
input_arrays: list[Array] = []
internal_arrays: list[ArrayOrNames] = []

for array in part_node_to_info.keys():
for array in part_node_to_info:
if isinstance(array, InputArgumentBase):
input_arrays.append(array)
else:
Expand Down Expand Up @@ -811,7 +807,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
# {{{ draw overall outputs

combined_array_to_id: dict[ArrayOrNames, str] = {}
for part_id in partition.parts.keys():
for part_id in partition.parts:
combined_array_to_id.update(part_id_to_array_to_id[part_id])

_emit_name_cluster(
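
Note: the .keys() removals here (and in the tests below) are Ruff's SIM118 (in-dict-keys): iterating over or membership-testing a dict already operates on its keys, so .keys() is redundant. A sketch with hypothetical data:

    part_to_nodes = {"part0": ["a", "b"], "part1": ["c"]}

    # Before (SIM118): `for part in part_to_nodes.keys(): ...`
    for part in part_to_nodes:            # iterates the same keys
        print(part, len(part_to_nodes[part]))

    assert "part0" in part_to_nodes       # membership needs no .keys() either
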
5 changes: 2 additions & 3 deletions test/test_apps.py
@@ -110,9 +110,8 @@ def __init__(self, fft_vec_gatherer):

     def map_index_lambda(self, expr):
         tags = expr.tags_of_type(FFTIntermediate)
-        if tags:
-            if self.finalized or expr in self.old_array_to_new_array:
-                return self.old_array_to_new_array[expr]
+        if tags and (self.finalized or expr in self.old_array_to_new_array):
+            return self.old_array_to_new_array[expr]

         return super().map_index_lambda(
                 expr.copy(expr=ConstantSizer()(expr.expr)))
6 changes: 2 additions & 4 deletions test/test_codegen.py
@@ -454,11 +454,9 @@ def test_slice(ctx_factory, shape):
     outputs = {}
     ref_outputs = {}

-    i = 0
-    for slice_ in generate_test_slices(shape):
+    for i, slice_ in enumerate(generate_test_slices(shape)):
         outputs[f"out_{i}"] = x[slice_]
         ref_outputs[f"out_{i}"] = x_in[slice_]
-        i += 1

     prog = pt.generate_loopy(outputs)

@@ -1956,7 +1954,7 @@ def build_expression(tracer):

     assert len(outputs) == len(expected)

-    for key in outputs.keys():
+    for key in outputs:
         np.testing.assert_allclose(outputs[key], expected[key])


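
Note: the test_slice hunk is Ruff's SIM113 (enumerate-for-loop): a manually maintained counter (`i = 0` ... `i += 1`) is replaced with enumerate, so the index can never drift out of sync with the loop body. A sketch with hypothetical names:

    def label_items(items: list[str]) -> dict[str, str]:
        # Before (SIM113):
        #     i = 0
        #     out = {}
        #     for item in items:
        #         out[f"out_{i}"] = item
        #         i += 1
        return {f"out_{i}": item for i, item in enumerate(items)}

    assert label_items(["x", "y"]) == {"out_0": "x", "out_1": "y"}
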
4 changes: 2 additions & 2 deletions test/test_distributed.py
@@ -135,7 +135,7 @@ def test_distributed_scheduler_counts():
     count_list = np.zeros(len(sizes))
     for i, tree_size in enumerate(sizes):
         needed_ids = {i: set() for i in range(int(tree_size))}
-        for key in needed_ids.keys():
+        for key in needed_ids:
             needed_ids[key] = {key-1} if key > 0 else set()
         _, count_list[i] = _schedule_task_batches_counted(needed_ids)

@@ -190,7 +190,7 @@ def test_distributed_scheduling_o_n_direct_dependents():
     count_list = np.zeros(len(sizes))
     for i, tree_size in enumerate(sizes):
         needed_ids = {i: set() for i in range(int(tree_size))}
-        for key in needed_ids.keys():
+        for key in needed_ids:
             for j in range(key):
                 needed_ids[key].add(j)
         _, count_list[i] = _schedule_task_batches_counted(needed_ids)
