24
24
@given (hh .mutually_promotable_dtypes (None ))
25
25
def test_result_type (dtypes ):
26
26
out = xp .result_type (* dtypes )
27
- ph .assert_dtype (' result_type' , dtypes , out , out_name = ' out' )
27
+ ph .assert_dtype (" result_type" , dtypes , out , out_name = " out" )
28
28
29
29
30
+ # The number and size of generated arrays is arbitrarily limited to prevent
31
+ # meshgrid() running out of memory.
30
32
@given (
31
- dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
33
+ dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
32
34
data = st .data (),
33
35
)
34
36
def test_meshgrid (dtypes , data ):
35
37
arrays = []
36
- shapes = data .draw (hh .mutually_broadcastable_shapes (len (dtypes )), label = 'shapes' )
38
+ shapes = data .draw (
39
+ hh .mutually_broadcastable_shapes (
40
+ len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
41
+ ),
42
+ label = "shapes" ,
43
+ )
37
44
for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
38
- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
45
+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
39
46
arrays .append (x )
40
47
out = xp .meshgrid (* arrays )
41
48
for i , x in enumerate (out ):
42
- ph .assert_dtype (' meshgrid' , dtypes , x .dtype , out_name = f' out[{ i } ].dtype' )
49
+ ph .assert_dtype (" meshgrid" , dtypes , x .dtype , out_name = f" out[{ i } ].dtype" )
43
50
44
51
45
52
@given (
@@ -50,10 +57,10 @@ def test_meshgrid(dtypes, data):
50
57
def test_concat (shape , dtypes , data ):
51
58
arrays = []
52
59
for i , dtype in enumerate (dtypes , 1 ):
53
- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
60
+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
54
61
arrays .append (x )
55
62
out = xp .concat (arrays )
56
- ph .assert_dtype (' concat' , dtypes , out .dtype )
63
+ ph .assert_dtype (" concat" , dtypes , out .dtype )
57
64
58
65
59
66
@given (
@@ -64,26 +71,26 @@ def test_concat(shape, dtypes, data):
64
71
def test_stack (shape , dtypes , data ):
65
72
arrays = []
66
73
for i , dtype in enumerate (dtypes , 1 ):
67
- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
74
+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
68
75
arrays .append (x )
69
76
out = xp .stack (arrays )
70
- ph .assert_dtype (' stack' , dtypes , out .dtype )
77
+ ph .assert_dtype (" stack" , dtypes , out .dtype )
71
78
72
79
73
80
bitwise_shift_funcs = [
74
- ' bitwise_left_shift' ,
75
- ' bitwise_right_shift' ,
76
- ' __lshift__' ,
77
- ' __rshift__' ,
78
- ' __ilshift__' ,
79
- ' __irshift__' ,
81
+ " bitwise_left_shift" ,
82
+ " bitwise_right_shift" ,
83
+ " __lshift__" ,
84
+ " __rshift__" ,
85
+ " __ilshift__" ,
86
+ " __irshift__" ,
80
87
]
81
88
82
89
83
90
# We pass kwargs to the elements strategy used by xps.arrays() so that we don't
84
91
# generate array elements that are erroneous or undefined for a function.
85
92
func_elements = defaultdict (
86
- lambda : None , {func : {' min_value' : 1 } for func in bitwise_shift_funcs }
93
+ lambda : None , {func : {" min_value" : 1 } for func in bitwise_shift_funcs }
87
94
)
88
95
89
96
@@ -94,7 +101,7 @@ def make_id(
94
101
) -> str :
95
102
f_args = dh .fmt_types (in_dtypes )
96
103
f_out_dtype = dh .dtype_to_name [out_dtype ]
97
- return f' { func_name } ({ f_args } ) -> { f_out_dtype } '
104
+ return f" { func_name } ({ f_args } ) -> { f_out_dtype } "
98
105
99
106
100
107
func_params : List [Param [str , Tuple [DataType , ...], DataType ]] = []
@@ -128,25 +135,25 @@ def make_id(
128
135
raise NotImplementedError ()
129
136
130
137
131
- @pytest .mark .parametrize (' func_name, in_dtypes, out_dtype' , func_params )
138
+ @pytest .mark .parametrize (" func_name, in_dtypes, out_dtype" , func_params )
132
139
@given (data = st .data ())
133
140
def test_func_promotion (func_name , in_dtypes , out_dtype , data ):
134
141
func = getattr (xp , func_name )
135
142
elements = func_elements [func_name ]
136
143
if len (in_dtypes ) == 1 :
137
144
x = data .draw (
138
145
xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
139
- label = 'x' ,
146
+ label = "x" ,
140
147
)
141
148
out = func (x )
142
149
else :
143
150
arrays = []
144
151
shapes = data .draw (
145
- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
152
+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
146
153
)
147
154
for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
148
155
x = data .draw (
149
- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
156
+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
150
157
)
151
158
arrays .append (x )
152
159
try :
@@ -161,46 +168,46 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
161
168
p = pytest .param (
162
169
(dtype1 , dtype2 ),
163
170
promoted_dtype ,
164
- id = make_id ('' , (dtype1 , dtype2 ), promoted_dtype ),
171
+ id = make_id ("" , (dtype1 , dtype2 ), promoted_dtype ),
165
172
)
166
173
promotion_params .append (p )
167
174
168
175
169
- @pytest .mark .parametrize (' in_dtypes, out_dtype' , promotion_params )
176
+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , promotion_params )
170
177
@given (shapes = hh .mutually_broadcastable_shapes (3 ), data = st .data ())
171
178
def test_where (in_dtypes , out_dtype , shapes , data ):
172
- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
173
- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
174
- cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = ' condition' )
179
+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
180
+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
181
+ cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = " condition" )
175
182
out = xp .where (cond , x1 , x2 )
176
- ph .assert_dtype (' where' , in_dtypes , out .dtype , out_dtype )
183
+ ph .assert_dtype (" where" , in_dtypes , out .dtype , out_dtype )
177
184
178
185
179
186
numeric_promotion_params = promotion_params [1 :]
180
187
181
188
182
- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
189
+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
183
190
@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 2 ), data = st .data ())
184
191
def test_tensordot (in_dtypes , out_dtype , shapes , data ):
185
- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
186
- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
192
+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
193
+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
187
194
out = xp .tensordot (x1 , x2 )
188
- ph .assert_dtype (' tensordot' , in_dtypes , out .dtype , out_dtype )
195
+ ph .assert_dtype (" tensordot" , in_dtypes , out .dtype , out_dtype )
189
196
190
197
191
- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
198
+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
192
199
@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 1 ), data = st .data ())
193
200
def test_vecdot (in_dtypes , out_dtype , shapes , data ):
194
- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
195
- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
201
+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
202
+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
196
203
out = xp .vecdot (x1 , x2 )
197
- ph .assert_dtype (' vecdot' , in_dtypes , out .dtype , out_dtype )
204
+ ph .assert_dtype (" vecdot" , in_dtypes , out .dtype , out_dtype )
198
205
199
206
200
207
op_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
201
208
op_to_symbol = {** dh .unary_op_to_symbol , ** dh .binary_op_to_symbol }
202
209
for op , symbol in op_to_symbol .items ():
203
- if op == ' __matmul__' :
210
+ if op == " __matmul__" :
204
211
continue
205
212
valid_in_dtypes = dh .func_in_dtypes [op ]
206
213
ndtypes = ph .nargs (op )
@@ -209,7 +216,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
209
216
out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
210
217
p = pytest .param (
211
218
op ,
212
- f' { symbol } x' ,
219
+ f" { symbol } x" ,
213
220
(in_dtype ,),
214
221
out_dtype ,
215
222
id = make_id (op , (in_dtype ,), out_dtype ),
@@ -221,42 +228,42 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
221
228
out_dtype = xp .bool if dh .func_returns_bool [op ] else promoted_dtype
222
229
p = pytest .param (
223
230
op ,
224
- f' x1 { symbol } x2' ,
231
+ f" x1 { symbol } x2" ,
225
232
(in_dtype1 , in_dtype2 ),
226
233
out_dtype ,
227
234
id = make_id (op , (in_dtype1 , in_dtype2 ), out_dtype ),
228
235
)
229
236
op_params .append (p )
230
237
# We generate params for abs seperately as it does not have an associated symbol
231
- for in_dtype in dh .func_in_dtypes [' __abs__' ]:
238
+ for in_dtype in dh .func_in_dtypes [" __abs__" ]:
232
239
p = pytest .param (
233
- ' __abs__' ,
234
- ' abs(x)' ,
240
+ " __abs__" ,
241
+ " abs(x)" ,
235
242
(in_dtype ,),
236
243
in_dtype ,
237
- id = make_id (' __abs__' , (in_dtype ,), in_dtype ),
244
+ id = make_id (" __abs__" , (in_dtype ,), in_dtype ),
238
245
)
239
246
op_params .append (p )
240
247
241
248
242
- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , op_params )
249
+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , op_params )
243
250
@given (data = st .data ())
244
251
def test_op_promotion (op , expr , in_dtypes , out_dtype , data ):
245
252
elements = func_elements [func_name ]
246
253
if len (in_dtypes ) == 1 :
247
254
x = data .draw (
248
255
xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
249
- label = 'x' ,
256
+ label = "x" ,
250
257
)
251
- out = eval (expr , {'x' : x })
258
+ out = eval (expr , {"x" : x })
252
259
else :
253
260
locals_ = {}
254
261
shapes = data .draw (
255
- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
262
+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
256
263
)
257
264
for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
258
- locals_ [f' x{ i } ' ] = data .draw (
259
- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
265
+ locals_ [f" x{ i } " ] = data .draw (
266
+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
260
267
)
261
268
try :
262
269
out = eval (expr , locals_ )
@@ -267,7 +274,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
267
274
268
275
inplace_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
269
276
for op , symbol in dh .inplace_op_to_symbol .items ():
270
- if op == ' __imatmul__' :
277
+ if op == " __imatmul__" :
271
278
continue
272
279
valid_in_dtypes = dh .func_in_dtypes [op ]
273
280
for (in_dtype1 , in_dtype2 ), promoted_dtype in dh .promotion_table .items ():
@@ -278,44 +285,44 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
278
285
):
279
286
p = pytest .param (
280
287
op ,
281
- f' x1 { symbol } x2' ,
288
+ f" x1 { symbol } x2" ,
282
289
(in_dtype1 , in_dtype2 ),
283
290
promoted_dtype ,
284
291
id = make_id (op , (in_dtype1 , in_dtype2 ), promoted_dtype ),
285
292
)
286
293
inplace_params .append (p )
287
294
288
295
289
- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , inplace_params )
296
+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , inplace_params )
290
297
@given (shapes = hh .mutually_broadcastable_shapes (2 ), data = st .data ())
291
298
def test_inplace_op_promotion (op , expr , in_dtypes , out_dtype , shapes , data ):
292
299
assume (len (shapes [0 ]) >= len (shapes [1 ]))
293
300
elements = func_elements [func_name ]
294
301
x1 = data .draw (
295
- xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = 'x1'
302
+ xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = "x1"
296
303
)
297
304
x2 = data .draw (
298
- xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = 'x2'
305
+ xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = "x2"
299
306
)
300
- locals_ = {'x1' : x1 , 'x2' : x2 }
307
+ locals_ = {"x1" : x1 , "x2" : x2 }
301
308
try :
302
309
exec (expr , locals_ )
303
310
except OverflowError :
304
311
reject ()
305
- x1 = locals_ ['x1' ]
306
- ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = ' x1.dtype' )
312
+ x1 = locals_ ["x1" ]
313
+ ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = " x1.dtype" )
307
314
308
315
309
316
op_scalar_params : List [Param [str , str , DataType , ScalarType , DataType ]] = []
310
317
for op , symbol in dh .binary_op_to_symbol .items ():
311
- if op == ' __matmul__' :
318
+ if op == " __matmul__" :
312
319
continue
313
320
for in_dtype in dh .func_in_dtypes [op ]:
314
321
out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
315
322
for in_stype in dh .dtype_to_scalars [in_dtype ]:
316
323
p = pytest .param (
317
324
op ,
318
- f' x { symbol } s' ,
325
+ f" x { symbol } s" ,
319
326
in_dtype ,
320
327
in_stype ,
321
328
out_dtype ,
@@ -324,57 +331,57 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
324
331
op_scalar_params .append (p )
325
332
326
333
327
- @pytest .mark .parametrize (' op, expr, in_dtype, in_stype, out_dtype' , op_scalar_params )
334
+ @pytest .mark .parametrize (" op, expr, in_dtype, in_stype, out_dtype" , op_scalar_params )
328
335
@given (data = st .data ())
329
336
def test_op_scalar_promotion (op , expr , in_dtype , in_stype , out_dtype , data ):
330
337
elements = func_elements [func_name ]
331
- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
332
- s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = ' scalar' )
338
+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
339
+ s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = " scalar" )
333
340
x = data .draw (
334
- xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = 'x'
341
+ xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = "x"
335
342
)
336
343
try :
337
- out = eval (expr , {'x' : x , 's' : s })
344
+ out = eval (expr , {"x" : x , "s" : s })
338
345
except OverflowError :
339
346
reject ()
340
347
ph .assert_dtype (op , (in_dtype , in_stype ), out .dtype , out_dtype )
341
348
342
349
343
350
inplace_scalar_params : List [Param [str , str , DataType , ScalarType ]] = []
344
351
for op , symbol in dh .inplace_op_to_symbol .items ():
345
- if op == ' __imatmul__' :
352
+ if op == " __imatmul__" :
346
353
continue
347
354
for dtype in dh .func_in_dtypes [op ]:
348
355
for in_stype in dh .dtype_to_scalars [dtype ]:
349
356
p = pytest .param (
350
357
op ,
351
- f' x { symbol } s' ,
358
+ f" x { symbol } s" ,
352
359
dtype ,
353
360
in_stype ,
354
361
id = make_id (op , (dtype , in_stype ), dtype ),
355
362
)
356
363
inplace_scalar_params .append (p )
357
364
358
365
359
- @pytest .mark .parametrize (' op, expr, dtype, in_stype' , inplace_scalar_params )
366
+ @pytest .mark .parametrize (" op, expr, dtype, in_stype" , inplace_scalar_params )
360
367
@given (data = st .data ())
361
368
def test_inplace_op_scalar_promotion (op , expr , dtype , in_stype , data ):
362
369
elements = func_elements [func_name ]
363
- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
364
- s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = ' scalar' )
370
+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
371
+ s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = " scalar" )
365
372
x = data .draw (
366
- xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = 'x'
373
+ xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = "x"
367
374
)
368
- locals_ = {'x' : x , 's' : s }
375
+ locals_ = {"x" : x , "s" : s }
369
376
try :
370
377
exec (expr , locals_ )
371
378
except OverflowError :
372
379
reject ()
373
- x = locals_ ['x' ]
374
- assert x .dtype == dtype , f' { x .dtype = !s} , but should be { dtype } '
375
- ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = ' x.dtype' )
380
+ x = locals_ ["x" ]
381
+ assert x .dtype == dtype , f" { x .dtype = !s} , but should be { dtype } "
382
+ ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = " x.dtype" )
376
383
377
384
378
- if __name__ == ' __main__' :
385
+ if __name__ == " __main__" :
379
386
for (i , j ), p in dh .promotion_table .items ():
380
- print (f' ({ i } , { j } ) -> { p } ' )
387
+ print (f" ({ i } , { j } ) -> { p } " )
0 commit comments