1
1
from hypothesis import given
2
+ from hypothesis import strategies as st
3
+ from hypothesis .control import assume
2
4
3
5
from . import _array_module as xp
6
+ from . import array_helpers as ah
7
+ from . import dtype_helpers as dh
4
8
from . import hypothesis_helpers as hh
9
+ from . import pytest_helpers as ph
5
10
from . import xps
11
+ from .test_manipulation_functions import assert_equals , axis_ndindex
6
12
7
13
8
14
# TODO: generate kwargs
@@ -12,8 +18,52 @@ def test_argsort(x):
12
18
# TODO
13
19
14
20
15
- # TODO: generate 0d arrays, generate kwargs
16
- @given (xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes (min_dims = 1 )))
17
- def test_sort (x ):
18
- xp .sort (x )
19
- # TODO
21
+ # TODO: Test with signed zeros and NaNs (and ignore them somehow)
22
+ @given (
23
+ x = xps .arrays (
24
+ dtype = xps .scalar_dtypes (),
25
+ shape = hh .shapes (min_dims = 1 , min_side = 1 ),
26
+ elements = {"allow_nan" : False },
27
+ ),
28
+ data = st .data (),
29
+ )
30
+ def test_sort (x , data ):
31
+ if dh .is_float_dtype (x .dtype ):
32
+ assume (not xp .any (x == - 0.0 ) and not xp .any (x == + 0.0 ))
33
+
34
+ kw = data .draw (
35
+ hh .kwargs (
36
+ axis = st .integers (- x .ndim , x .ndim - 1 ),
37
+ descending = st .booleans (),
38
+ stable = st .booleans (),
39
+ ),
40
+ label = "kw" ,
41
+ )
42
+
43
+ out = xp .sort (x , ** kw )
44
+
45
+ ph .assert_dtype ("sort" , out .dtype , x .dtype )
46
+ ph .assert_shape ("sort" , out .shape , x .shape , ** kw )
47
+ axis = kw .get ("axis" , - 1 )
48
+ _axis = axis if axis >= 0 else x .ndim + axis
49
+ descending = kw .get ("descending" , False )
50
+ scalar_type = dh .get_scalar_type (x .dtype )
51
+ for idx in axis_ndindex (x .shape , _axis ):
52
+ f_idx = ", " .join (str (i ) if isinstance (i , int ) else ":" for i in idx )
53
+ indexed_x = x [idx ]
54
+ indexed_out = out [idx ]
55
+ out_indices = list (ah .ndindex (indexed_x .shape ))
56
+ elements = [scalar_type (indexed_x [idx2 ]) for idx2 in out_indices ]
57
+ indices_order = sorted (
58
+ range (len (out_indices )), key = elements .__getitem__ , reverse = descending
59
+ )
60
+ x_indices = [out_indices [o ] for o in indices_order ]
61
+ for out_idx , x_idx in zip (out_indices , x_indices ):
62
+ assert_equals (
63
+ "sort" ,
64
+ f"x[{ f_idx } ][{ x_idx } ]" ,
65
+ indexed_x [x_idx ],
66
+ f"out[{ f_idx } ][{ out_idx } ]" ,
67
+ indexed_out [out_idx ],
68
+ ** kw ,
69
+ )
0 commit comments