2
2
from operator import mul
3
3
from math import sqrt
4
4
5
- from hypothesis .strategies import (lists , integers , builds , sampled_from ,
5
+ from hypothesis import assume
6
+ from hypothesis .strategies import (lists , integers , sampled_from ,
6
7
shared , floats , just , composite , one_of ,
7
8
none , booleans )
8
- from hypothesis .extra .numpy import mutually_broadcastable_shapes
9
- from hypothesis import assume
9
+ from hypothesis .extra .array_api import make_strategies_namespace
10
10
11
11
from .pytest_helpers import nargs
12
12
from .array_helpers import (dtype_ranges , integer_dtype_objects ,
15
15
integer_or_boolean_dtype_objects , dtype_objects )
16
16
from ._array_module import full , float32 , float64 , bool as bool_dtype , _UndefinedStub
17
17
from . import _array_module
18
+ from . import _array_module as xp
18
19
19
20
from .function_stubs import elementwise_functions
20
21
21
22
23
+ xps = make_strategies_namespace (xp )
24
+
25
+
22
26
# Set this to True to not fail tests just because a dtype isn't implemented.
23
27
# If no compatible dtype is implemented for a given test, the test will fail
24
28
# with a hypothesis health check error. Note that this functionality will not
42
46
boolean_dtypes = boolean_dtypes .filter (lambda x : not isinstance (x , _UndefinedStub ))
43
47
dtypes = dtypes .filter (lambda x : not isinstance (x , _UndefinedStub ))
44
48
45
- shared_dtypes = shared (dtypes )
49
+ shared_dtypes = shared (dtypes , key = "dtype" )
46
50
51
+ # TODO: Importing things from test_type_promotion should be replaced by
52
+ # something that won't cause a circular import. Right now we use @st.composite
53
+ # only because it returns a lazy-evaluated strategy - in the future this method
54
+ # should remove the composite wrapper, just returning sampled_from(dtype_pairs)
55
+ # instead of drawing from it.
47
56
@composite
48
57
def mutually_promotable_dtypes (draw , dtype_objects = dtype_objects ):
49
58
from .test_type_promotion import dtype_mapping , promotion_table
@@ -55,17 +64,20 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
55
64
# pairs (XXX: Can we redesign the strategies so that they can prefer
56
65
# shrinking dtypes over values?)
57
66
sorted_table = sorted (promotion_table )
58
- sorted_table = sorted (sorted_table , key = lambda ij : - 1 if ij [0 ] == ij [1 ] else sorted_table .index (ij ))
59
- dtype_pairs = [(dtype_mapping [i ], dtype_mapping [j ]) for i , j in
60
- sorted_table ]
61
-
62
- filtered_dtype_pairs = [(i , j ) for i , j in dtype_pairs if i in
63
- dtype_objects and j in dtype_objects ]
67
+ sorted_table = sorted (
68
+ sorted_table , key = lambda ij : - 1 if ij [0 ] == ij [1 ] else sorted_table .index (ij )
69
+ )
70
+ dtype_pairs = [(dtype_mapping [i ], dtype_mapping [j ]) for i , j in sorted_table ]
64
71
if FILTER_UNDEFINED_DTYPES :
65
- filtered_dtype_pairs = [(i , j ) for i , j in filtered_dtype_pairs
66
- if not isinstance (i , _UndefinedStub )
67
- and not isinstance (j , _UndefinedStub )]
68
- return draw (sampled_from (filtered_dtype_pairs ))
72
+ dtype_pairs = [(i , j ) for i , j in dtype_pairs
73
+ if not isinstance (i , _UndefinedStub )
74
+ and not isinstance (j , _UndefinedStub )]
75
+ dtype_pairs = [(i , j ) for i , j in dtype_pairs if i in dtype_objects and j in dtype_objects ]
76
+ return draw (sampled_from (dtype_pairs ))
77
+
78
+ shared_mutually_promotable_dtype_pairs = shared (
79
+ mutually_promotable_dtypes (), key = "mutually_promotable_dtype_pair"
80
+ )
69
81
70
82
# shared() allows us to draw either the function or the function name and they
71
83
# will both correspond to the same function.
@@ -96,36 +108,35 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
96
108
return lists (elements , min_size = min_size , max_size = max_size ,
97
109
unique_by = unique_by , unique = unique ).map (tuple )
98
110
99
- shapes = tuples (integers (0 , 10 )).filter (lambda shape : prod (shape ) < MAX_ARRAY_SIZE )
100
-
101
111
# Use this to avoid memory errors with NumPy.
102
112
# See https://github.com/numpy/numpy/issues/15753
103
- shapes = tuples (integers (0 , 10 )).filter (
104
- lambda shape : prod ([i for i in shape if i ]) < MAX_ARRAY_SIZE )
113
+ shapes = xps .array_shapes (min_dims = 0 , min_side = 0 ).filter (
114
+ lambda shape : prod (i for i in shape if i ) < MAX_ARRAY_SIZE
115
+ )
105
116
106
- two_mutually_broadcastable_shapes = mutually_broadcastable_shapes (num_shapes = 2 )\
117
+ two_mutually_broadcastable_shapes = xps . mutually_broadcastable_shapes (num_shapes = 2 )\
107
118
.map (lambda S : S .input_shapes )\
108
- .filter (lambda S : all (prod ([ i for i in shape if i ] ) < MAX_ARRAY_SIZE for shape in S ))
119
+ .filter (lambda S : all (prod (i for i in shape if i ) < MAX_ARRAY_SIZE for shape in S ))
109
120
110
121
@composite
111
- def two_broadcastable_shapes (draw , shapes = shapes ):
122
+ def two_broadcastable_shapes (draw ):
112
123
"""
113
124
This will produce two shapes (shape1, shape2) such that shape2 can be
114
125
broadcast to shape1.
115
-
116
126
"""
117
127
from .test_broadcasting import broadcast_shapes
118
-
119
- shape1 , shape2 = draw (two_mutually_broadcastable_shapes )
120
- if broadcast_shapes (shape1 , shape2 ) != shape1 :
121
- assume (False )
128
+ shape1 , shape2 = draw (two_mutually_broadcastable_shapes )
129
+ assume (broadcast_shapes (shape1 , shape2 ) == shape1 )
122
130
return (shape1 , shape2 )
123
131
124
132
sizes = integers (0 , MAX_ARRAY_SIZE )
125
133
sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
126
134
127
135
# TODO: Generate general arrays here, rather than just scalars.
128
- numeric_arrays = builds (full , just ((1 ,)), floats ())
136
+ numeric_arrays = xps .arrays (
137
+ dtype = shared (xps .floating_dtypes (), key = 'dtypes' ),
138
+ shape = shared (xps .array_shapes (), key = 'shapes' ),
139
+ )
129
140
130
141
@composite
131
142
def scalars (draw , dtypes , finite = False ):
@@ -230,3 +241,22 @@ def multiaxis_indices(draw, shapes):
230
241
extra = draw (lists (one_of (integer_indices (sizes ), slices (sizes )), min_size = 0 , max_size = 3 ))
231
242
res += extra
232
243
return tuple (res )
244
+
245
+
246
+ shared_arrays1 = xps .arrays (
247
+ dtype = shared_mutually_promotable_dtype_pairs .map (lambda pair : pair [0 ]),
248
+ shape = shared (two_mutually_broadcastable_shapes , key = "shape_pair" ).map (lambda pair : pair [0 ]),
249
+ )
250
+ shared_arrays2 = xps .arrays (
251
+ dtype = shared_mutually_promotable_dtype_pairs .map (lambda pair : pair [1 ]),
252
+ shape = shared (two_mutually_broadcastable_shapes , key = "shape_pair" ).map (lambda pair : pair [1 ]),
253
+ )
254
+
255
+
256
+ @composite
257
+ def kwargs (draw , ** kw ):
258
+ result = {}
259
+ for k , strat in kw .items ():
260
+ if draw (booleans ()):
261
+ result [k ] = draw (strat )
262
+ return result
0 commit comments