Skip to content

Commit 0f53876

Browse files
committed
address comments
1 parent dd22aae commit 0f53876

File tree

3 files changed

+90
-44
lines changed

3 files changed

+90
-44
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def dot(a, b, out=None):
147147
# TODO: use specific scalar-vector kernel
148148
return dpnp.multiply(a, b, out=out)
149149

150+
# numpy.dot does not allow casting even if it is safe
151+
# casting="no" is used in the following
150152
if a_ndim == 1 and b_ndim == 1:
151-
return dpnp_dot(a, b, out=out)
153+
return dpnp_dot(a, b, out=out, casting="no")
152154

153-
# NumPy does not allow casting even if it is safe
154-
# casting="no" is used in the following
155155
if a_ndim == 2 and b_ndim == 2:
156156
return dpnp.matmul(a, b, out=out, casting="no")
157157

@@ -753,6 +753,7 @@ def matmul(
753753
Type to use in computing the matrix product. By default, the returned
754754
array will have data type that is determined by considering
755755
Promotion Type Rule and device capabilities.
756+
Default: ``None``.
756757
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
757758
Controls what kind of data casting may occur.
758759
Default: ``"same_kind"``.
@@ -1203,7 +1204,7 @@ def vecdot(
12031204
.. math::
12041205
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i
12051206
1206-
where the sum is over the last dimension (unless axis is specified) and
1207+
where the sum is over the last dimension (unless `axis` is specified) and
12071208
where :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i`
12081209
is complex and the identity otherwise.
12091210
@@ -1221,16 +1222,17 @@ def vecdot(
12211222
removed. If not provided or ``None``, a freshly-allocated array is
12221223
used.
12231224
Default: ``None``.
1224-
dtype : {None, dtype}, optional
1225-
Type to use in computing the vector dot product. By default, the
1226-
returned array will have data type that is determined by considering
1227-
Promotion Type Rule and device capabilities.
12281225
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
12291226
Controls what kind of data casting may occur.
12301227
Default: ``"same_kind"``.
12311228
order : {"C", "F", "A", "K", None}, optional
12321229
Memory layout of the newly output array, if parameter `out` is ``None``.
12331230
Default: ``"K"``.
1231+
dtype : {None, dtype}, optional
1232+
Type to use in computing the vector dot product. By default, the
1233+
returned array will have data type that is determined by considering
1234+
Promotion Type Rule and device capabilities.
1235+
Default: ``None``.
12341236
axes : {None, list of tuples}, optional
12351237
A list of tuples with indices of axes the matrix product should operate
12361238
on. For instance, for the signature of ``(i),(i)->()``, the base

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _define_contig_flag(x):
189189

190190
def _define_dim_flags(x, axis):
191191
"""
192-
Define useful flags for the main calculation in dpnp_matmul.
192+
Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
193193
x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
194194
except for one of them), for instance, if x.shape = (1, 1, 1, 2),
195195
then x_is_1D = True
@@ -220,7 +220,7 @@ def _define_dim_flags(x, axis):
220220
return x_is_2D, x_is_1D, x_base_is_1D
221221

222222

223-
def _get_result_shape(x1, x2, out, func, np_flag):
223+
def _get_result_shape(x1, x2, out, _get_result_shape_fn, np_flag):
224224
"""
225225
Three task are completed in this function:
226226
- Get the shape of the result array.
@@ -239,15 +239,7 @@ def _get_result_shape(x1, x2, out, func, np_flag):
239239
"The second input array does not have enough dimensions (has 0, but requires at least 1)"
240240
)
241241

242-
if func == "matmul":
243-
x1, x2, result_shape = _get_result_shape_matmul(
244-
x1, x2, x1_ndim, x2_ndim
245-
)
246-
else: # func == "vecdot"
247-
assert func == "vecdot"
248-
x1, x2, result_shape = _get_result_shape_vecdot(
249-
x1, x2, x1_ndim, x2_ndim
250-
)
242+
x1, x2, result_shape = _get_result_shape_fn(x1, x2, x1_ndim, x2_ndim)
251243

252244
if out is not None:
253245
out_shape = out.shape
@@ -474,7 +466,7 @@ def _shape_error(shape1, shape2, func, err_msg):
474466
elif func == "vecdot":
475467
signature = "(n?,),(n?,)->()"
476468
else:
477-
# applicable when err_msg == 3
469+
# applicable when err_msg == 2
478470
assert func is None
479471

480472
if err_msg == 0:
@@ -655,7 +647,7 @@ def dpnp_cross(a, b, cp):
655647
return cp
656648

657649

658-
def dpnp_dot(a, b, /, out=None, *, conjugate=False):
650+
def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
659651
"""
660652
Return the dot product of two arrays.
661653
@@ -717,8 +709,7 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
717709
if dot_dtype != res_dtype:
718710
result = result.astype(res_dtype, copy=False)
719711

720-
# numpy.dot does not allow casting even if it is safe
721-
return dpnp.get_result_array(result, out, casting="no")
712+
return dpnp.get_result_array(result, out, casting=casting)
722713

723714

724715
def dpnp_kron(a, b, a_ndim, b_ndim):
@@ -773,8 +764,10 @@ def dpnp_matmul(
773764
order = "F"
774765
else:
775766
order = "C"
776-
777-
if order in "kK":
767+
elif order in "kK":
768+
# For order="K", we return order="C" to align with NumPy behavior
769+
# It is different than logic used in dpnp_vecdot because NumPy
770+
# behaves differently for matmul and vecdot
778771
order = "C"
779772

780773
x1_ndim = x1.ndim
@@ -806,7 +799,7 @@ def dpnp_matmul(
806799
)
807800

808801
x1, x2, result_shape = _get_result_shape(
809-
x1, x2, out, "matmul", NumPy_special_behavior
802+
x1, x2, out, _get_result_shape_matmul, NumPy_special_behavior
810803
)
811804

812805
# Determine the appropriate data types
@@ -1000,6 +993,9 @@ def dpnp_vecdot(
1000993
_validate_out_array(out, exec_q)
1001994

1002995
if order in "aAkK":
996+
# This logic is also used for order="K" to align with NumPy behavior.
997+
# It is different than logic used in dpnp_matmul because NumPy
998+
# behaves differently for matmul and vecdot
1003999
if x1.flags.fnc and x2.flags.fnc:
10041000
order = "F"
10051001
else:
@@ -1035,7 +1031,7 @@ def dpnp_vecdot(
10351031
)
10361032

10371033
x1, x2, result_shape = _get_result_shape(
1038-
x1, x2, out, "vecdot", NumPy_special_behavior
1034+
x1, x2, out, _get_result_shape_vecdot, NumPy_special_behavior
10391035
)
10401036

10411037
# Determine the appropriate data types
@@ -1047,21 +1043,7 @@ def dpnp_vecdot(
10471043
_, x2_is_1D, _ = _define_dim_flags(x2, axis=-1)
10481044

10491045
if x1.size == 0 or x2.size == 0:
1050-
order = "C" if order in "kK" else order
1051-
result = _create_result_array(
1052-
x1,
1053-
x2,
1054-
out,
1055-
shape=result_shape,
1056-
dtype=res_dtype,
1057-
usm_type=res_usm_type,
1058-
sycl_queue=exec_q,
1059-
order=order,
1060-
)
1061-
if numpy.prod(result_shape) == 0:
1062-
return result
1063-
result.fill(0)
1064-
return result
1046+
call_flag = "trivial"
10651047
elif x1_is_1D and x2_is_1D:
10661048
call_flag = "dot"
10671049
# arrays are inehrently 1D, make them 1D
@@ -1072,7 +1054,20 @@ def dpnp_vecdot(
10721054
call_flag = "vecdot"
10731055

10741056
# dispatch to proper function call
1075-
if call_flag == "dot":
1057+
if call_flag == "trivial":
1058+
result = _create_result_array(
1059+
x1,
1060+
x2,
1061+
out,
1062+
shape=result_shape,
1063+
dtype=res_dtype,
1064+
usm_type=res_usm_type,
1065+
sycl_queue=exec_q,
1066+
order=order,
1067+
)
1068+
if numpy.prod(result_shape) != 0:
1069+
result.fill(0)
1070+
elif call_flag == "dot":
10761071
if out is not None and out.shape != ():
10771072
result = dpnp_dot(x1, x2, out=None, conjugate=True)
10781073
else:

tests/test_product.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1525,6 +1525,27 @@ def test_order(self, order1, order2, order, shape):
15251525
assert result.flags.f_contiguous == expected.flags.f_contiguous
15261526
assert_dtype_allclose(result, expected)
15271527

1528+
@pytest.mark.parametrize("order", ["C", "F", "K", "A"])
1529+
@pytest.mark.parametrize(
1530+
"shape", [(2, 4, 0), (4, 0, 5)], ids=["(2, 4, 0)", "(4, 0, 5)"]
1531+
)
1532+
def test_order_trivial(self, order, shape):
1533+
# input is both c-contiguous and f-contiguous
1534+
a = numpy.ones(shape)
1535+
a_dp = dpnp.asarray(a)
1536+
1537+
result = dpnp.vecdot(a_dp, a_dp, order=order)
1538+
expected = numpy.vecdot(a, a, order=order)
1539+
if shape == (2, 4, 0) and order == "A":
1540+
# NumPy does not behave correctly for this case, for order="A",
1541+
# if input is both c- and f-contiguous, output is c-contiguous
1542+
assert result.flags.c_contiguous
1543+
assert not result.flags.f_contiguous
1544+
else:
1545+
assert result.flags.c_contiguous == expected.flags.c_contiguous
1546+
assert result.flags.f_contiguous == expected.flags.f_contiguous
1547+
assert_dtype_allclose(result, expected)
1548+
15281549
@pytest.mark.parametrize(
15291550
"order1, order2, out_order",
15301551
[
@@ -1538,7 +1559,7 @@ def test_order(self, order1, order2, order, shape):
15381559
("F", "F", "C"),
15391560
],
15401561
)
1541-
def test_out(self, order1, order2, out_order):
1562+
def test_out_order(self, order1, order2, out_order):
15421563
a1 = numpy.arange(20).reshape(5, 4, order=order1)
15431564
a2 = numpy.arange(20).reshape(5, 4, order=order2)
15441565

@@ -1555,6 +1576,34 @@ def test_out(self, order1, order2, out_order):
15551576
assert result.flags.f_contiguous == expected.flags.f_contiguous
15561577
assert_dtype_allclose(result, expected)
15571578

1579+
@pytest.mark.parametrize("dtype1", get_all_dtypes(no_none=True))
1580+
@pytest.mark.parametrize("dtype2", get_all_dtypes(no_none=True))
1581+
@pytest.mark.parametrize(
1582+
"shape_pair",
1583+
[
1584+
((4,), ()),
1585+
((1, 1, 4), (1, 1)),
1586+
((6, 7, 4, 3), (6, 7, 4)),
1587+
((2, 0), (2,)), # zero-size inputs, 1D output
1588+
((3, 0, 4), (3, 0)), # zero-size output
1589+
],
1590+
)
1591+
def test_out_dtype(self, dtype1, dtype2, shape_pair):
1592+
shape1, shape2 = shape_pair
1593+
a = numpy.ones(shape1, dtype=dtype1)
1594+
b = dpnp.asarray(a)
1595+
1596+
out_np = numpy.empty(shape2, dtype=dtype2)
1597+
out_dp = dpnp.asarray(out_np)
1598+
1599+
if dpnp.can_cast(dtype1, dtype2, casting="same_kind"):
1600+
result = dpnp.vecdot(b, b, out=out_dp)
1601+
expected = numpy.vecdot(a, a, out=out_np)
1602+
assert_dtype_allclose(result, expected)
1603+
else:
1604+
with pytest.raises(TypeError):
1605+
dpnp.vecdot(b, b, out=out_dp)
1606+
15581607
@pytest.mark.parametrize(
15591608
"out_shape",
15601609
[

0 commit comments

Comments
 (0)