Skip to content

Commit 4061965

Browse files
committed
Include func/op and param dtypes in type promotion error messages
1 parent 2f7933a commit 4061965

File tree

1 file changed

+49
-26
lines changed

1 file changed

+49
-26
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
33
"""
44
from collections import defaultdict
5+
from functools import lru_cache
56
from typing import Tuple, Type, Union, List
67

78
import pytest
@@ -22,6 +23,26 @@
2223
Param = Tuple
2324

2425

26+
@lru_cache
27+
def fmt_types(types: Tuple[Union[DT, ScalarType], ...]) -> str:
28+
f_types = []
29+
for type_ in types:
30+
try:
31+
f_types.append(dh.dtype_to_name[type_])
32+
except KeyError:
33+
# i.e. dtype is bool, int, or float
34+
f_types.append(type_.__name__)
35+
return ', '.join(f_types)
36+
37+
38+
def assert_dtype(test_case: str, result_name: str, dtype: DT, expected: DT):
39+
msg = (
40+
f'{result_name}={dh.dtype_to_name[dtype]}, '
41+
f'but should be {dh.dtype_to_name[expected]} [{test_case}]'
42+
)
43+
assert dtype == expected, msg
44+
45+
2546
def multi_promotable_dtypes(
2647
allow_bool: bool = True,
2748
) -> st.SearchStrategy[Tuple[DT, ...]]:
@@ -39,8 +60,9 @@ def multi_promotable_dtypes(
3960
@given(multi_promotable_dtypes())
4061
def test_result_type(dtypes):
4162
out = xp.result_type(*dtypes)
42-
expected = dh.result_type(*dtypes)
43-
assert out == expected, f'{out=!s}, but should be {expected}'
63+
assert_dtype(
64+
f'result_type({fmt_types(dtypes)})', 'out', out, dh.result_type(*dtypes)
65+
)
4466

4567

4668
@given(
@@ -56,10 +78,9 @@ def test_meshgrid(dtypes, kw, data):
5678
arrays.append(x)
5779
out = xp.meshgrid(*arrays, **kw)
5880
expected = dh.result_type(*dtypes)
59-
for i in range(len(out)):
60-
assert (
61-
out[i].dtype == expected
62-
), f'out[{i}]={out[i].dtype}, but should be {expected}'
81+
test_case = f'meshgrid({fmt_types(dtypes)})'
82+
for i, x in enumerate(out):
83+
assert_dtype(test_case, f'out[{i}].dtype', x.dtype, expected)
6384

6485

6586
@given(
@@ -74,8 +95,9 @@ def test_concat(shape, dtypes, kw, data):
7495
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
7596
arrays.append(x)
7697
out = xp.concat(arrays, **kw)
77-
expected = dh.result_type(*dtypes)
78-
assert out.dtype == expected, f'{out.dtype=!s}, but should be {expected}'
98+
assert_dtype(
99+
f'concat({fmt_types(dtypes)})', 'out.dtype', out.dtype, dh.result_type(*dtypes)
100+
)
79101

80102

81103
@given(
@@ -90,8 +112,9 @@ def test_stack(shape, dtypes, kw, data):
90112
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
91113
arrays.append(x)
92114
out = xp.stack(arrays, **kw)
93-
expected = dh.result_type(*dtypes)
94-
assert out.dtype == expected, f'{out.dtype=!s}, but should be {expected}'
115+
assert_dtype(
116+
f'stack({fmt_types(dtypes)})', 'out.dtype', out.dtype, dh.result_type(*dtypes)
117+
)
95118

96119

97120
bitwise_shift_funcs = [
@@ -115,14 +138,7 @@ def test_stack(shape, dtypes, kw, data):
115138
def make_id(
116139
func_name: str, in_dtypes: Tuple[Union[DT, ScalarType], ...], out_dtype: DT
117140
) -> str:
118-
f_in_dtypes = []
119-
for dtype in in_dtypes:
120-
try:
121-
f_in_dtypes.append(dh.dtype_to_name[dtype])
122-
except KeyError:
123-
# i.e. dtype is bool, int, or float
124-
f_in_dtypes.append(dtype.__name__)
125-
f_args = ', '.join(f_in_dtypes)
141+
f_args = fmt_types(in_dtypes)
126142
f_out_dtype = dh.dtype_to_name[out_dtype]
127143
return f'{func_name}({f_args}) -> {f_out_dtype}'
128144

@@ -183,7 +199,9 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
183199
out = func(*arrays)
184200
except OverflowError:
185201
reject()
186-
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
202+
assert_dtype(
203+
f'{func_name}({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
204+
)
187205

188206

189207
promotion_params: List[Param[Tuple[DT, DT], DT]] = []
@@ -203,7 +221,7 @@ def test_where(in_dtypes, out_dtype, shapes, data):
203221
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
204222
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label='condition')
205223
out = xp.where(cond, x1, x2)
206-
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
224+
assert_dtype(f'where({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype)
207225

208226

209227
numeric_promotion_params = promotion_params[1:]
@@ -215,7 +233,7 @@ def test_matmul(in_dtypes, out_dtype, shapes, data):
215233
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
216234
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
217235
out = xp.matmul(x1, x2)
218-
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
236+
assert_dtype(f'matmul({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype)
219237

220238

221239
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
@@ -224,7 +242,9 @@ def test_tensordot(in_dtypes, out_dtype, shapes, data):
224242
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
225243
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
226244
out = xp.tensordot(x1, x2)
227-
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
245+
assert_dtype(
246+
f'tensordot({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
247+
)
228248

229249

230250
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
@@ -233,7 +253,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
233253
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
234254
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
235255
out = xp.vecdot(x1, x2)
236-
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
256+
assert_dtype(f'vecdot({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype)
237257

238258

239259
op_params: List[Param[str, str, Tuple[DT, ...], DT]] = []
@@ -301,7 +321,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
301321
out = eval(expr, locals_)
302322
except OverflowError:
303323
reject()
304-
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
324+
assert_dtype(f'{op}({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype)
305325

306326

307327
inplace_params: List[Param[str, str, Tuple[DT, ...], DT]] = []
@@ -342,7 +362,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
342362
except OverflowError:
343363
reject()
344364
x1 = locals_['x1']
345-
assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}'
365+
assert_dtype(f'{op}({fmt_types(in_dtypes)})', 'x1.dtype', x1.dtype, out_dtype)
346366

347367

348368
op_scalar_params: List[Param[str, str, DT, ScalarType, DT]] = []
@@ -376,7 +396,9 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
376396
out = eval(expr, {'x': x, 's': s})
377397
except OverflowError:
378398
reject()
379-
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
399+
assert_dtype(
400+
f'{op}({fmt_types((in_dtype, in_stype))})', 'out.dtype', out.dtype, out_dtype
401+
)
380402

381403

382404
inplace_scalar_params: List[Param[str, str, DT, ScalarType]] = []
@@ -411,6 +433,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
411433
reject()
412434
x = locals_['x']
413435
assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}'
436+
assert_dtype(f'{op}({fmt_types((dtype, in_stype))})', 'x.dtype', x.dtype, dtype)
414437

415438

416439
if __name__ == '__main__':

0 commit comments

Comments
 (0)