Skip to content

Commit 0fbf329

Browse files
authored
following NEP-50 for dpnp.einsum (#2120)
1 parent 7c45c10 commit 0fbf329

File tree

4 files changed

+17
-22
lines changed

4 files changed

+17
-22
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def correlate(x1, x2, mode="valid"):
370370
-----------
371371
Input arrays are supported as :obj:`dpnp.ndarray`.
372372
Size and shape of input arrays are supported to be equal.
373-
Parameter `mode` is supported only with default value ``"valid``.
373+
Parameter `mode` is supported only with default value ``"valid"``.
374374
Otherwise the function will be executed sequentially on CPU.
375375
Input array data types are limited by supported DPNP :ref:`Data types`.
376376

dpnp/dpnp_utils/dpnp_utils_einsum.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@
3333
from dpctl.utils import ExecutionPlacementError
3434

3535
import dpnp
36-
from dpnp.dpnp_utils import get_usm_allocations
37-
38-
from ..dpnp_array import dpnp_array
36+
from dpnp.dpnp_array import dpnp_array
37+
from dpnp.dpnp_utils import get_usm_allocations, map_dtype_to_device
3938

4039
_einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
4140

@@ -1027,17 +1026,16 @@ def dpnp_einsum(
10271026
"Input and output allocation queues are not compatible"
10281027
)
10291028

1030-
result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype
10311029
for id, a in enumerate(operands):
10321030
if dpnp.isscalar(a):
1031+
scalar_dtype = map_dtype_to_device(type(a), exec_q.sycl_device)
10331032
operands[id] = dpnp.array(
1034-
a, dtype=result_dtype, usm_type=res_usm_type, sycl_queue=exec_q
1033+
a, dtype=scalar_dtype, usm_type=res_usm_type, sycl_queue=exec_q
10351034
)
1035+
arrays.append(operands[id])
10361036
result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype
1037-
if order in ["a", "A"]:
1038-
order = (
1039-
"F" if not any(arr.flags.c_contiguous for arr in arrays) else "C"
1040-
)
1037+
if order in "aA":
1038+
order = "F" if all(arr.flags.fnc for arr in arrays) else "C"
10411039

10421040
input_subscripts = [
10431041
_parse_ellipsis_subscript(sub, idx, ndim=arr.ndim)

tests/test_linalg.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,14 +1139,12 @@ def check_einsum_sums(self, dtype, do_opt=False):
11391139
result = inp.einsum(*args, dtype="?", casting="unsafe", optimize=do_opt)
11401140
assert_dtype_allclose(result, expected)
11411141

1142-
# with an scalar, NumPy < 2.0.0 uses the other input arrays to
1143-
# determine the output type while for NumPy > 2.0.0 the scalar
1144-
# with default machine dtype is used to determine the output
1145-
# data type
1142+
# NumPy >= 2.0 follows NEP-50 to determine the output dtype when one of
1143+
# the inputs is a scalar while NumPy < 2.0 does not
11461144
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
1147-
check_type = True
1148-
else:
11491145
check_type = False
1146+
else:
1147+
check_type = True
11501148
a = numpy.arange(9, dtype=dtype)
11511149
a_dp = inp.array(a)
11521150
expected = numpy.einsum(",i->", 3, a)
@@ -1712,7 +1710,7 @@ def test_broadcasting_dot_cases(self):
17121710

17131711
def test_output_order(self):
17141712
# Ensure output order is respected for optimize cases, the below
1715-
# conraction should yield a reshaped tensor view
1713+
# contraction should yield a reshaped tensor view
17161714
a = inp.ones((2, 3, 5), order="F")
17171715
b = inp.ones((4, 3), order="F")
17181716

tests/third_party/cupy/linalg_tests/test_einsum.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,13 +475,12 @@ def test_einsum_binary(self, xp, dtype_a, dtype_b):
475475

476476

477477
class TestEinSumBinaryOperationWithScalar:
478-
# with an scalar, NumPy < 2.0.0 uses the other input arrays to determine
479-
# the output type while for NumPy > 2.0.0 the scalar with default machine
480-
# dtype is used to determine the output type
478+
# NumPy >= 2.0 follows NEP-50 to determine the output dtype when one of
479+
# the inputs is a scalar while NumPy < 2.0 does not
481480
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
482-
type_check = has_support_aspect64()
483-
else:
484481
type_check = False
482+
else:
483+
type_check = has_support_aspect64()
485484

486485
@testing.for_all_dtypes()
487486
@testing.numpy_cupy_allclose(contiguous_check=False, type_check=type_check)

0 commit comments

Comments
 (0)