2
2
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
3
3
"""
4
4
from collections import defaultdict
5
+ from functools import lru_cache
5
6
from typing import Tuple , Type , Union , List
6
7
7
8
import pytest
22
23
Param = Tuple
23
24
24
25
26
+ @lru_cache
27
+ def fmt_types (types : Tuple [Union [DT , ScalarType ], ...]) -> str :
28
+ f_types = []
29
+ for type_ in types :
30
+ try :
31
+ f_types .append (dh .dtype_to_name [type_ ])
32
+ except KeyError :
33
+ # i.e. dtype is bool, int, or float
34
+ f_types .append (type_ .__name__ )
35
+ return ', ' .join (f_types )
36
+
37
+
38
+ def assert_dtype (test_case : str , result_name : str , dtype : DT , expected : DT ):
39
+ msg = (
40
+ f'{ result_name } ={ dh .dtype_to_name [dtype ]} , '
41
+ f'but should be { dh .dtype_to_name [expected ]} [{ test_case } ]'
42
+ )
43
+ assert dtype == expected , msg
44
+
45
+
25
46
def multi_promotable_dtypes (
26
47
allow_bool : bool = True ,
27
48
) -> st .SearchStrategy [Tuple [DT , ...]]:
@@ -39,8 +60,9 @@ def multi_promotable_dtypes(
39
60
@given (multi_promotable_dtypes ())
40
61
def test_result_type (dtypes ):
41
62
out = xp .result_type (* dtypes )
42
- expected = dh .result_type (* dtypes )
43
- assert out == expected , f'{ out = !s} , but should be { expected } '
63
+ assert_dtype (
64
+ f'result_type({ fmt_types (dtypes )} )' , 'out' , out , dh .result_type (* dtypes )
65
+ )
44
66
45
67
46
68
@given (
@@ -56,10 +78,9 @@ def test_meshgrid(dtypes, kw, data):
56
78
arrays .append (x )
57
79
out = xp .meshgrid (* arrays , ** kw )
58
80
expected = dh .result_type (* dtypes )
59
- for i in range (len (out )):
60
- assert (
61
- out [i ].dtype == expected
62
- ), f'out[{ i } ]={ out [i ].dtype } , but should be { expected } '
81
+ test_case = f'meshgrid({ fmt_types (dtypes )} )'
82
+ for i , x in enumerate (out ):
83
+ assert_dtype (test_case , f'out[{ i } ].dtype' , x .dtype , expected )
63
84
64
85
65
86
@given (
@@ -74,8 +95,9 @@ def test_concat(shape, dtypes, kw, data):
74
95
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f'x{ i } ' )
75
96
arrays .append (x )
76
97
out = xp .concat (arrays , ** kw )
77
- expected = dh .result_type (* dtypes )
78
- assert out .dtype == expected , f'{ out .dtype = !s} , but should be { expected } '
98
+ assert_dtype (
99
+ f'concat({ fmt_types (dtypes )} )' , 'out.dtype' , out .dtype , dh .result_type (* dtypes )
100
+ )
79
101
80
102
81
103
@given (
@@ -90,8 +112,9 @@ def test_stack(shape, dtypes, kw, data):
90
112
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f'x{ i } ' )
91
113
arrays .append (x )
92
114
out = xp .stack (arrays , ** kw )
93
- expected = dh .result_type (* dtypes )
94
- assert out .dtype == expected , f'{ out .dtype = !s} , but should be { expected } '
115
+ assert_dtype (
116
+ f'stack({ fmt_types (dtypes )} )' , 'out.dtype' , out .dtype , dh .result_type (* dtypes )
117
+ )
95
118
96
119
97
120
bitwise_shift_funcs = [
@@ -115,14 +138,7 @@ def test_stack(shape, dtypes, kw, data):
115
138
def make_id (
116
139
func_name : str , in_dtypes : Tuple [Union [DT , ScalarType ], ...], out_dtype : DT
117
140
) -> str :
118
- f_in_dtypes = []
119
- for dtype in in_dtypes :
120
- try :
121
- f_in_dtypes .append (dh .dtype_to_name [dtype ])
122
- except KeyError :
123
- # i.e. dtype is bool, int, or float
124
- f_in_dtypes .append (dtype .__name__ )
125
- f_args = ', ' .join (f_in_dtypes )
141
+ f_args = fmt_types (in_dtypes )
126
142
f_out_dtype = dh .dtype_to_name [out_dtype ]
127
143
return f'{ func_name } ({ f_args } ) -> { f_out_dtype } '
128
144
@@ -183,7 +199,9 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
183
199
out = func (* arrays )
184
200
except OverflowError :
185
201
reject ()
186
- assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
202
+ assert_dtype (
203
+ f'{ func_name } ({ fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
204
+ )
187
205
188
206
189
207
promotion_params : List [Param [Tuple [DT , DT ], DT ]] = []
@@ -203,7 +221,7 @@ def test_where(in_dtypes, out_dtype, shapes, data):
203
221
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
204
222
cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = 'condition' )
205
223
out = xp .where (cond , x1 , x2 )
206
- assert out .dtype == out_dtype , f' { out .dtype = !s } , but should be { out_dtype } '
224
+ assert_dtype ( f'where( { fmt_types ( in_dtypes ) } )' , ' out.dtype' , out .dtype , out_dtype )
207
225
208
226
209
227
numeric_promotion_params = promotion_params [1 :]
@@ -215,7 +233,7 @@ def test_matmul(in_dtypes, out_dtype, shapes, data):
215
233
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
216
234
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
217
235
out = xp .matmul (x1 , x2 )
218
- assert out .dtype == out_dtype , f' { out .dtype = !s } , but should be { out_dtype } '
236
+ assert_dtype ( f'matmul( { fmt_types ( in_dtypes ) } )' , ' out.dtype' , out .dtype , out_dtype )
219
237
220
238
221
239
@pytest .mark .parametrize ('in_dtypes, out_dtype' , numeric_promotion_params )
@@ -224,7 +242,9 @@ def test_tensordot(in_dtypes, out_dtype, shapes, data):
224
242
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
225
243
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
226
244
out = xp .tensordot (x1 , x2 )
227
- assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
245
+ assert_dtype (
246
+ f'tensordot({ fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
247
+ )
228
248
229
249
230
250
@pytest .mark .parametrize ('in_dtypes, out_dtype' , numeric_promotion_params )
@@ -233,7 +253,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
233
253
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
234
254
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
235
255
out = xp .vecdot (x1 , x2 )
236
- assert out .dtype == out_dtype , f' { out .dtype = !s } , but should be { out_dtype } '
256
+ assert_dtype ( f'vecdot( { fmt_types ( in_dtypes ) } )' , ' out.dtype' , out .dtype , out_dtype )
237
257
238
258
239
259
op_params : List [Param [str , str , Tuple [DT , ...], DT ]] = []
@@ -301,7 +321,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
301
321
out = eval (expr , locals_ )
302
322
except OverflowError :
303
323
reject ()
304
- assert out .dtype == out_dtype , f' { out .dtype = !s } , but should be { out_dtype } '
324
+ assert_dtype ( f' { op } ( { fmt_types ( in_dtypes ) } )' , ' out.dtype' , out .dtype , out_dtype )
305
325
306
326
307
327
inplace_params : List [Param [str , str , Tuple [DT , ...], DT ]] = []
@@ -342,7 +362,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
342
362
except OverflowError :
343
363
reject ()
344
364
x1 = locals_ ['x1' ]
345
- assert x1 .dtype == out_dtype , f' { x1 .dtype = !s } , but should be { out_dtype } '
365
+ assert_dtype ( f' { op } ( { fmt_types ( in_dtypes ) } )' , ' x1.dtype' , x1 .dtype , out_dtype )
346
366
347
367
348
368
op_scalar_params : List [Param [str , str , DT , ScalarType , DT ]] = []
@@ -376,7 +396,9 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
376
396
out = eval (expr , {'x' : x , 's' : s })
377
397
except OverflowError :
378
398
reject ()
379
- assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
399
+ assert_dtype (
400
+ f'{ op } ({ fmt_types ((in_dtype , in_stype ))} )' , 'out.dtype' , out .dtype , out_dtype
401
+ )
380
402
381
403
382
404
inplace_scalar_params : List [Param [str , str , DT , ScalarType ]] = []
@@ -411,6 +433,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
411
433
reject ()
412
434
x = locals_ ['x' ]
413
435
assert x .dtype == dtype , f'{ x .dtype = !s} , but should be { dtype } '
436
+ assert_dtype (f'{ op } ({ fmt_types ((dtype , in_stype ))} )' , 'x.dtype' , x .dtype , dtype )
414
437
415
438
416
439
if __name__ == '__main__' :
0 commit comments