Skip to content

Commit

Permalink
Permit "same_kind" casting for element-wise in-place operators (#2170)
Browse files Browse the repository at this point in the history
The PR proposes to permit `"same_kind"` casting for element-wise
in-place operators. The implementation leverages the dpctl changes added
in the scope of [PR#1827](IntelPython/dpctl#1827).

It also adds callbacks to support in-place bit-wise operators (leveraging
the dpctl changes from
[PR#1447](IntelPython/dpctl#1447)).

The PR also removes a temporary workaround from `dpnp.unwrap`; the removal
depends on the implemented changes.
  • Loading branch information
antonwolfy authored Jan 11, 2025
1 parent f7c0938 commit 3d02b6b
Show file tree
Hide file tree
Showing 7 changed files with 1,200 additions and 889 deletions.
18 changes: 14 additions & 4 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,20 @@ def __call__(
"as an argument, but both were provided."
)

x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

if (
isinstance(x1, dpnp_array)
and x1 is out
and order == "K"
and dtype is None
):
# in-place operation
super()._inplace_op(x1_usm, x2_usm)
return x1

if order is None:
order = "K"
elif order in "afkcAFKC":
Expand All @@ -344,9 +358,6 @@ def __call__(
"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')"
)

x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)

if dtype is not None:
if dpnp.isscalar(x1):
x1_usm = dpt.asarray(
Expand All @@ -368,7 +379,6 @@ def __call__(
x1_usm = dpt.astype(x1_usm, dtype, copy=False)
x2_usm = dpt.astype(x2_usm, dtype, copy=False)

out_usm = None if out is None else dpnp.get_usm_ndarray(out)
res_usm = super().__call__(x1_usm, x2_usm, out=out_usm, order=order)

if out is not None and isinstance(out, dpnp_array):
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def __imatmul__(self, other):
axes = [(-2, -1), (-2, -1), (-2, -1)]

try:
dpnp.matmul(self, other, out=self, axes=axes)
dpnp.matmul(self, other, out=self, dtype=self.dtype, axes=axes)
except AxisError:
# AxisError should indicate that the axes argument didn't work out
# which should mean the second operand not being 2 dimensional.
Expand Down
5 changes: 5 additions & 0 deletions dpnp/dpnp_iface_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def binary_repr(num, width=None):
ti._bitwise_and_result_type,
ti._bitwise_and,
_BITWISE_AND_DOCSTRING,
binary_inplace_fn=ti._bitwise_and_inplace,
)


Expand Down Expand Up @@ -285,6 +286,7 @@ def binary_repr(num, width=None):
ti._bitwise_or_result_type,
ti._bitwise_or,
_BITWISE_OR_DOCSTRING,
binary_inplace_fn=ti._bitwise_or_inplace,
)


Expand Down Expand Up @@ -366,6 +368,7 @@ def binary_repr(num, width=None):
ti._bitwise_xor_result_type,
ti._bitwise_xor,
_BITWISE_XOR_DOCSTRING,
binary_inplace_fn=ti._bitwise_xor_inplace,
)


Expand Down Expand Up @@ -518,6 +521,7 @@ def binary_repr(num, width=None):
ti._bitwise_left_shift_result_type,
ti._bitwise_left_shift,
_LEFT_SHIFT_DOCSTRING,
binary_inplace_fn=ti._bitwise_left_shift_inplace,
)

bitwise_left_shift = left_shift # bitwise_left_shift is an alias for left_shift
Expand Down Expand Up @@ -595,6 +599,7 @@ def binary_repr(num, width=None):
ti._bitwise_right_shift_result_type,
ti._bitwise_right_shift,
_RIGHT_SHIFT_DOCSTRING,
binary_inplace_fn=ti._bitwise_right_shift_inplace,
)

# bitwise_right_shift is an alias for right_shift
Expand Down
4 changes: 1 addition & 3 deletions dpnp/dpnp_iface_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2450,7 +2450,5 @@ def unwrap(p, discont=None, axis=-1, *, period=2 * dpnp.pi):

up = dpnp.astype(p, dtype=dt, copy=True)
up[slice1] = p[slice1]
# TODO: replace, once dpctl-1757 resolved
# up[slice1] += ph_correct.cumsum(axis=axis)
up[slice1] += ph_correct.cumsum(axis=axis, dtype=dt)
up[slice1] += ph_correct.cumsum(axis=axis)
return up
Loading

0 comments on commit 3d02b6b

Please sign in to comment.