From d60f5b6139d173aabbc10f48bbe283eb455d1171 Mon Sep 17 00:00:00 2001 From: "Komarova, Evseniia" Date: Tue, 12 Nov 2024 14:42:08 +0100 Subject: [PATCH] Add dpnp.common_type implementation --- dpnp/dpnp_iface.py | 68 ++++++++++++++++++++ tests/third_party/cupy/test_type_routines.py | 1 - 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/dpnp/dpnp_iface.py b/dpnp/dpnp_iface.py index 7d1a76448e5..88a35d101e5 100644 --- a/dpnp/dpnp_iface.py +++ b/dpnp/dpnp_iface.py @@ -61,6 +61,7 @@ "as_usm_ndarray", "check_limitations", "check_supported_arrays_type", + "common_type", "default_float_type", "from_dlpack", "get_dpnp_descriptor", @@ -406,6 +407,73 @@ def check_supported_arrays_type(*arrays, scalar_type=False, all_scalars=False): return True +# determine the "minimum common type" for a group of arrays +array_precision = { + dpnp.float16: 0, + dpnp.float32: 1, + dpnp.float64: 2, + dpnp.complex64: 3, + dpnp.complex128: 4, +} + +array_type = { + "float": {0: dpnp.float16, 1: dpnp.float32, 2: dpnp.float64}, + "complex": {3: dpnp.complex64, 4: dpnp.complex128}, +} + + +def common_type(*arrays): + """ + Return a scalar type which is common to the input arrays. + + The return type will always be an inexact (i.e. floating point) scalar + type, even if all the arrays are integer arrays. If one of the inputs is + an integer array, the minimum precision type that is returned is a + 64-bit floating point dtype. + + For full documentation refer to :obj:`numpy.common_type` + + Parameters + ---------- + array1, array2, ... : {dpnp.ndarray, usm_ndarray} + Input arrays. + + Returns + ------- + out : data type code + Data type code. + + Examples + -------- + >>> import dpnp as np + >>> np.common_type(np.arange(2, dtype=np.float32)) + + >>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2)) + + >>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0])) + + + """ + dpnp.check_supported_arrays_type(*arrays) + + is_complex = False + max_precision = 0 + + for a in arrays: + t = a.dtype.type + + if dpnp.issubdtype(t, dpnp.complexfloating): + is_complex = True + if dpnp.issubdtype(t, dpnp.integer): + t = dpnp.float64 + + max_precision = max(max_precision, array_precision.get(t, 0)) + + if is_complex: + return array_type["complex"].get(max_precision, dpnp.complex128) + return array_type["float"].get(max_precision, dpnp.float64) + + def default_float_type(device=None, sycl_queue=None): """ Return a floating type used by default in DPNP depending on device diff --git a/tests/third_party/cupy/test_type_routines.py b/tests/third_party/cupy/test_type_routines.py index c1e39a19cd0..2ea60c8fc46 100644 --- a/tests/third_party/cupy/test_type_routines.py +++ b/tests/third_party/cupy/test_type_routines.py @@ -46,7 +46,6 @@ def test_can_cast(self, xp, from_dtype, to_dtype): return ret -@pytest.mark.skip("dpnp.common_type() is not implemented yet") class TestCommonType(unittest.TestCase): @testing.numpy_cupy_equal() def test_common_type_empty(self, xp):