Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff sim #573

Merged
merged 2 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
author = "Pytato Contributors"

ver_dic = {}
exec(compile(open("../pytato/version.py").read(), "../pytato/version.py",
"exec"), ver_dic)
with open("../pytato/version.py") as vfile:
exec(compile(vfile.read(), "../pytato/version.py", "exec"), ver_dic)

version = ".".join(str(x) for x in ver_dic["VERSION"])
release = ver_dic["VERSION_TEXT"]

Expand Down
7 changes: 7 additions & 0 deletions doc/design.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ Reserved Identifiers
as automatically generated names (if required) in
:attr:`pytato.array.IndexLambda.bindings`.


.. note::

Other than the iname names (``_[0-9]+``), these naming conventions are not
compulsory. The above is merely intended to set aside parts of the namespace
for this purpose that are guaranteed not to be trampled on by the user.

Tags
----

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ extend-select = [
"RUF",
"UP",
"TC",
"SIM",
]
extend-ignore = [
"E226",
Expand Down
13 changes: 6 additions & 7 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,12 @@ def _get_indices_from_input_subscript(subscript: str,

# }}}

if is_output:
if len(normalized_indices) != len(set(normalized_indices)):
repeated_idx = next(idx
for idx in normalized_indices
if normalized_indices.count(idx) > 1)
raise ValueError(f"Output subscript '{subscript}' contains "
f"'{repeated_idx}' multiple times.")
if is_output and len(normalized_indices) != len(set(normalized_indices)):
repeated_idx = next(idx
for idx in normalized_indices
if normalized_indices.count(idx) > 1)
raise ValueError(f"Output subscript '{subscript}' contains "
f"'{repeated_idx}' multiple times.")

return tuple(normalized_indices)

Expand Down
25 changes: 9 additions & 16 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,7 @@ def _augment_array_dataclass(
attr_tuple_hash = ", ".join(f"self.{fld.name}"
for fld in fields(cls) if fld.name != "non_equality_tags")

if attr_tuple_hash:
attr_tuple_hash = f"({attr_tuple_hash},)"
else:
attr_tuple_hash = "()"
attr_tuple_hash = f"({attr_tuple_hash},)" if attr_tuple_hash else "()"

from pytools.codegen import remove_common_indentation
augment_code = remove_common_indentation(
Expand Down Expand Up @@ -414,11 +411,10 @@ def _dataclass_setstate(self, state):
# place, or it inherits a value but does not set it itself.
sets_mapper_method = "_mapper_method" in mm_cls.__dict__

if sets_mapper_method:
if default_mapper_method_name == mm_cls._mapper_method:
warn(f"Explicit _mapper_method on {mm_cls} not needed, default matches "
"explicit assignment. Just delete the explicit assignment.",
stacklevel=3)
if sets_mapper_method and default_mapper_method_name == mm_cls._mapper_method:
warn(f"Explicit _mapper_method on {mm_cls} not needed, default matches "
"explicit assignment. Just delete the explicit assignment.",
stacklevel=3)

if not sets_mapper_method:
mm_cls._mapper_method = intern(default_mapper_method_name)
Expand Down Expand Up @@ -1500,9 +1496,9 @@ def einsum(subscripts: str, *operands: Array,
raise ValueError(f"'{idx}' is not a reduction dim.")

for descr in index_to_descr.values():
if isinstance(descr, EinsumReductionAxis):
if descr not in redn_axis_to_redn_descr:
redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset())
if (isinstance(descr, EinsumReductionAxis)
and descr not in redn_axis_to_redn_descr):
redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset())

# }}}

Expand Down Expand Up @@ -2356,10 +2352,7 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN,
if order != "C":
raise ValueError("Only C-ordered arrays supported for now.")

if dtype is None:
conv_dtype = np.array(fill_value).dtype
else:
conv_dtype = np.dtype(dtype)
conv_dtype = np.array(fill_value).dtype if dtype is None else np.dtype(dtype)

shape = normalize_shape(shape)

Expand Down
29 changes: 13 additions & 16 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,7 @@

def is_symbolic_index(o: object) -> TypeIs[SymbolicIndex]:
if isinstance(o, tuple):
for i in o:
if not is_integral_scalar_expression(i):
return False
return True
return all(is_integral_scalar_expression(i) for i in o)
else:
return False

Expand Down Expand Up @@ -280,18 +277,18 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if isinstance(expr, Placeholder | SizeParam | DataWrapper):
if expr.name is not None:
try:
ary = self.name_to_input[expr.name]
except KeyError:
self.name_to_input[expr.name] = expr
else:
if ary is not expr:
from pytato.diagnostic import NameClashError
raise NameClashError(
"Received two separate instances of inputs "
f"named '{expr.name}'.")
if (isinstance(expr, Placeholder | SizeParam | DataWrapper)
and expr.name is not None):
try:
ary = self.name_to_input[expr.name]
except KeyError:
self.name_to_input[expr.name] = expr
else:
if ary is not expr:
from pytato.diagnostic import NameClashError
raise NameClashError(
"Received two separate instances of inputs "
f"named '{expr.name}'.")


def check_validity_of_outputs(exprs: DictOfNamedArrays) -> None:
Expand Down
2 changes: 1 addition & 1 deletion pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@


def is_integral_scalar_expression(expr: object) -> TypeIs[IntegralScalarExpression]:
return isinstance(expr, int | np.integer) or isinstance(expr, prim.ExpressionNode)
return isinstance(expr, int | np.integer | prim.ExpressionNode)


def parse(s: str) -> ScalarExpression:
Expand Down
4 changes: 2 additions & 2 deletions pytato/target/loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@
**kwargs: Any) -> Any:
"""Convenience function for launching a :mod:`pyopencl` computation."""

if __debug__:
if __debug__: # noqa: SIM102
if set(kwargs.keys()) & set(self.bound_arguments.keys()):
raise ValueError("Got arguments that were previously bound: "
f"{set(kwargs.keys()) & set(self.bound_arguments.keys())}.")
Expand All @@ -243,7 +243,7 @@
isinstance(arg, np.ndarray) or np.isscalar(arg)
for arg in kwargs.values())

return self.program(queue,

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.

Check warning on line 246 in pytato/target/loopy/__init__.py

View workflow job for this annotation

GitHub Actions / Conda Pytest

TranslationUnit.__call__ will become uncached in 2024, meaning it will incur possibly substantial compilation cost with every invocation. Use TranslationUnit.executor to obtain an object that holds longer-lived caches.
allocator=allocator, wait_for=wait_for,
out_host=out_host,
**updated_kwargs)
Expand Down Expand Up @@ -314,7 +314,7 @@
**kwargs: Any) -> Any:
"""Convenience function for launching a :mod:`pyopencl` computation."""

if __debug__:
if __debug__: # noqa: SIM102
if set(kwargs.keys()) & set(self.bound_arguments.keys()):
raise ValueError("Got arguments that were previously bound: "
f"{set(kwargs.keys()) & set(self.bound_arguments.keys())}.")
Expand Down
5 changes: 1 addition & 4 deletions pytato/target/loopy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,10 +904,7 @@ def add_store(name: str, expr: Array, result: ImplementedResult,

# Make the instruction
from loopy.kernel.instruction import make_assignment
if indices:
assignee = prim.Variable(name)[indices]
else:
assignee = prim.Variable(name)
assignee = prim.Variable(name)[indices] if indices else prim.Variable(name)
insn_id = state.insn_id_gen(f"{name}_store")
insn = make_assignment((assignee,),
loopy_expr,
Expand Down
27 changes: 7 additions & 20 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,8 @@ def _get_result_axis_length(axis_lengths: list[ShapeComponent]
) -> ShapeComponent:
result_axis_len = axis_lengths[0]
for axis_len in axis_lengths[1:]:
if are_shape_components_equal(axis_len, result_axis_len):
pass
elif are_shape_components_equal(axis_len, 1):
if (are_shape_components_equal(axis_len, result_axis_len)
or are_shape_components_equal(axis_len, 1)):
pass
elif are_shape_components_equal(result_axis_len, 1):
result_axis_len = axis_len
Expand Down Expand Up @@ -233,7 +232,7 @@ def cast_to_result_type(
array: ArrayOrScalar,
expr: ScalarExpression | Bool
) -> ScalarExpression | Bool:
if ((isinstance(array, Array) or isinstance(array, np.generic))
if ((isinstance(array, Array | np.generic))
and array.dtype != result_dtype):
# Loopy's type casts don't like casting to bool
assert result_dtype != np.bool_
Expand Down Expand Up @@ -460,15 +459,9 @@ def _normalize_slice(slice_: slice,
if -axis_len <= start < axis_len:
start = start % axis_len
elif start >= axis_len:
if step > 0:
start = axis_len
else:
start = axis_len - 1
start = axis_len if step > 0 else axis_len - 1
else:
if step > 0:
start = 0
else:
start = -1
start = 0 if step > 0 else -1
else:
raise NotImplementedError

Expand All @@ -479,15 +472,9 @@ def _normalize_slice(slice_: slice,
if -axis_len <= stop < axis_len:
stop = stop % axis_len
elif stop >= axis_len:
if step > 0:
stop = axis_len
else:
stop = axis_len - 1
stop = axis_len if step > 0 else axis_len - 1
else:
if step > 0:
stop = 0
else:
stop = -1
stop = 0 if step > 0 else -1
else:
raise NotImplementedError

Expand Down
12 changes: 4 additions & 8 deletions pytato/visualization/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,7 @@ def handle_unsupported_array(self, expr: Array) -> None:
continue
attr = getattr(expr, field.name)

if isinstance(attr, Array):
self.rec(attr)
info.edges[field.name] = attr

elif isinstance(attr, AbstractResultWithNamedArrays):
if isinstance(attr, Array | AbstractResultWithNamedArrays):
self.rec(attr)
info.edges[field.name] = attr

Expand Down Expand Up @@ -632,7 +628,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:

for part in partition.parts.values():
array_to_id = {}
for array in part_id_func_to_node_info[part.pid, None].keys():
for array in part_id_func_to_node_info[part.pid, None]:
if isinstance(array, Placeholder):
# Placeholders are only emitted once
if array in placeholder_to_id:
Expand Down Expand Up @@ -704,7 +700,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
input_arrays: list[Array] = []
internal_arrays: list[ArrayOrNames] = []

for array in part_node_to_info.keys():
for array in part_node_to_info:
if isinstance(array, InputArgumentBase):
input_arrays.append(array)
else:
Expand Down Expand Up @@ -811,7 +807,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
# {{{ draw overall outputs

combined_array_to_id: dict[ArrayOrNames, str] = {}
for part_id in partition.parts.keys():
for part_id in partition.parts:
combined_array_to_id.update(part_id_to_array_to_id[part_id])

_emit_name_cluster(
Expand Down
5 changes: 2 additions & 3 deletions test/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ def __init__(self, fft_vec_gatherer):

def map_index_lambda(self, expr):
tags = expr.tags_of_type(FFTIntermediate)
if tags:
if self.finalized or expr in self.old_array_to_new_array:
return self.old_array_to_new_array[expr]
if tags and (self.finalized or expr in self.old_array_to_new_array):
return self.old_array_to_new_array[expr]

return super().map_index_lambda(
expr.copy(expr=ConstantSizer()(expr.expr)))
Expand Down
6 changes: 2 additions & 4 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,9 @@ def test_slice(ctx_factory, shape):
outputs = {}
ref_outputs = {}

i = 0
for slice_ in generate_test_slices(shape):
for i, slice_ in enumerate(generate_test_slices(shape)):
outputs[f"out_{i}"] = x[slice_]
ref_outputs[f"out_{i}"] = x_in[slice_]
i += 1

prog = pt.generate_loopy(outputs)

Expand Down Expand Up @@ -1956,7 +1954,7 @@ def build_expression(tracer):

assert len(outputs) == len(expected)

for key in outputs.keys():
for key in outputs:
np.testing.assert_allclose(outputs[key], expected[key])


Expand Down
4 changes: 2 additions & 2 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_distributed_scheduler_counts():
count_list = np.zeros(len(sizes))
for i, tree_size in enumerate(sizes):
needed_ids = {i: set() for i in range(int(tree_size))}
for key in needed_ids.keys():
for key in needed_ids:
needed_ids[key] = {key-1} if key > 0 else set()
_, count_list[i] = _schedule_task_batches_counted(needed_ids)

Expand Down Expand Up @@ -190,7 +190,7 @@ def test_distributed_scheduling_o_n_direct_dependents():
count_list = np.zeros(len(sizes))
for i, tree_size in enumerate(sizes):
needed_ids = {i: set() for i in range(int(tree_size))}
for key in needed_ids.keys():
for key in needed_ids:
for j in range(key):
needed_ids[key].add(j)
_, count_list[i] = _schedule_task_batches_counted(needed_ids)
Expand Down
Loading