Skip to content

Commit 5422dc9

Browse files
authored
Merge branch 'master' into harlowjo/debug/wheels-tests
2 parents ba98594 + c844b26 commit 5422dc9

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed

dpnp/tests/test_ndarray.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,3 +460,29 @@ def test_clip():
460460
expected = numpy.clip(numpy_array, 3, 7)
461461

462462
assert_array_equal(expected, result)
463+
464+
465+
def test_rmatmul_dpnp_array():
466+
a = dpnp.ones(10)
467+
b = dpnp.ones(10)
468+
469+
class Dummy(dpnp.ndarray):
470+
def __init__(self, x):
471+
self._array_obj = x.get_array()
472+
473+
def __matmul__(self, other):
474+
return NotImplemented
475+
476+
d = Dummy(a)
477+
478+
result = d @ b
479+
expected = a @ b
480+
assert (result == expected).all()
481+
482+
483+
def test_rmatmul_numpy_array():
484+
a = dpnp.ones(10)
485+
b = numpy.ones(10)
486+
487+
with pytest.raises(TypeError):
488+
b @ a

dpnp/tests/test_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import dpctl
2+
import dpctl.tensor as dpt
3+
import numpy
4+
import pytest
5+
6+
import dpnp
7+
8+
9+
class TestIsSupportedArrayOrScalar:
10+
@pytest.mark.parametrize(
11+
"array",
12+
[
13+
dpnp.array([1, 2, 3]),
14+
dpnp.array(1),
15+
dpt.asarray([1, 2, 3]),
16+
],
17+
)
18+
def test_valid_arrays(self, array):
19+
assert dpnp.is_supported_array_or_scalar(array) is True
20+
21+
@pytest.mark.parametrize(
22+
"value",
23+
[
24+
42,
25+
True,
26+
"1",
27+
],
28+
)
29+
def test_valid_scalars(self, value):
30+
assert dpnp.is_supported_array_or_scalar(value) is True
31+
32+
@pytest.mark.parametrize(
33+
"array",
34+
[
35+
[1, 2, 3],
36+
(1, 2, 3),
37+
None,
38+
numpy.array([1, 2, 3]),
39+
],
40+
)
41+
def test_invalid_arrays(self, array):
42+
assert not dpnp.is_supported_array_or_scalar(array) is True
43+
44+
45+
class TestSynchronizeArrayData:
46+
@pytest.mark.parametrize(
47+
"array",
48+
[
49+
dpnp.array([1, 2, 3]),
50+
dpt.asarray([1, 2, 3]),
51+
],
52+
)
53+
def test_synchronize_array_data(self, array):
54+
a_copy = dpnp.copy(array, sycl_queue=array.sycl_queue)
55+
try:
56+
dpnp.synchronize_array_data(a_copy)
57+
except Exception as e:
58+
pytest.fail(f"synchronize_array_data failed: {e}")
59+
60+
@pytest.mark.parametrize(
61+
"input",
62+
[
63+
[1, 2, 3],
64+
numpy.array([1, 2, 3]),
65+
],
66+
)
67+
def test_unsupported_type(self, input):
68+
with pytest.raises(TypeError):
69+
dpnp.synchronize_array_data(input)

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def _get_cmdclass():
4343
Topic :: Software Development
4444
Topic :: Scientific/Engineering
4545
Operating System :: Microsoft :: Windows
46+
Operating System :: POSIX :: Linux
4647
Operating System :: POSIX
4748
Operating System :: Unix
4849
"""
@@ -82,4 +83,6 @@ def _get_cmdclass():
8283
]
8384
},
8485
include_package_data=False,
86+
python_requires=">=3.9,<3.14",
87+
install_requires=["dpctl >= 0.19.0dev0", "numpy"],
8588
)

0 commit comments

Comments
 (0)