Skip to content

Commit f108941

Browse files
committed
Cover most things in test_sort
1 parent d6c4fc6 commit f108941

File tree

1 file changed

+55
-5
lines changed

1 file changed

+55
-5
lines changed

array_api_tests/test_sorting.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
from hypothesis import given
2+
from hypothesis import strategies as st
3+
from hypothesis.control import assume
24

35
from . import _array_module as xp
6+
from . import array_helpers as ah
7+
from . import dtype_helpers as dh
48
from . import hypothesis_helpers as hh
9+
from . import pytest_helpers as ph
510
from . import xps
11+
from .test_manipulation_functions import assert_equals, axis_ndindex
612

713

814
# TODO: generate kwargs
@@ -12,8 +18,52 @@ def test_argsort(x):
1218
# TODO
1319

1420

15-
# TODO: generate 0d arrays, generate kwargs
16-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)))
17-
def test_sort(x):
18-
xp.sort(x)
19-
# TODO
21+
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
22+
@given(
23+
x=xps.arrays(
24+
dtype=xps.scalar_dtypes(),
25+
shape=hh.shapes(min_dims=1, min_side=1),
26+
elements={"allow_nan": False},
27+
),
28+
data=st.data(),
29+
)
30+
def test_sort(x, data):
31+
if dh.is_float_dtype(x.dtype):
32+
assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
33+
34+
kw = data.draw(
35+
hh.kwargs(
36+
axis=st.integers(-x.ndim, x.ndim - 1),
37+
descending=st.booleans(),
38+
stable=st.booleans(),
39+
),
40+
label="kw",
41+
)
42+
43+
out = xp.sort(x, **kw)
44+
45+
ph.assert_dtype("sort", out.dtype, x.dtype)
46+
ph.assert_shape("sort", out.shape, x.shape, **kw)
47+
axis = kw.get("axis", -1)
48+
_axis = axis if axis >= 0 else x.ndim + axis
49+
descending = kw.get("descending", False)
50+
scalar_type = dh.get_scalar_type(x.dtype)
51+
for idx in axis_ndindex(x.shape, _axis):
52+
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
53+
indexed_x = x[idx]
54+
indexed_out = out[idx]
55+
out_indices = list(ah.ndindex(indexed_x.shape))
56+
elements = [scalar_type(indexed_x[idx2]) for idx2 in out_indices]
57+
indices_order = sorted(
58+
range(len(out_indices)), key=elements.__getitem__, reverse=descending
59+
)
60+
x_indices = [out_indices[o] for o in indices_order]
61+
for out_idx, x_idx in zip(out_indices, x_indices):
62+
assert_equals(
63+
"sort",
64+
f"x[{f_idx}][{x_idx}]",
65+
indexed_x[x_idx],
66+
f"out[{f_idx}][{out_idx}]",
67+
indexed_out[out_idx],
68+
**kw,
69+
)

0 commit comments

Comments
 (0)