Skip to content

Commit 66a1fd4

Browse files
committed
Favour lists compared to tuples for ph.assert_dtypes()
Tuples give the impression of `in_dtype` being hetereogenous
1 parent 2077986 commit 66a1fd4

File tree

6 files changed

+24
-25
lines changed

6 files changed

+24
-25
lines changed

array_api_tests/meta/test_pytest_helpers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66

77
def test_assert_dtype():
8-
ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16)
8+
ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16)
99
with raises(AssertionError):
10-
ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32)
11-
ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool)
12-
ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8)
13-
ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool)
10+
ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32)
11+
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
12+
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
13+
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)

array_api_tests/pytest_helpers.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from inspect import getfullargspec
3-
from typing import Any, Dict, Optional, Tuple, Union
3+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
44

55
from . import _array_module as xp
66
from . import array_helpers as ah
@@ -71,15 +71,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str:
7171

7272
def assert_dtype(
7373
func_name: str,
74-
in_dtypes: Union[DataType, Tuple[DataType, ...]],
74+
in_dtype: Union[DataType, Sequence[DataType]],
7575
out_dtype: DataType,
7676
expected: Optional[DataType] = None,
7777
*,
7878
repr_name: str = "out.dtype",
7979
):
80-
if not isinstance(in_dtypes, tuple):
81-
in_dtypes = (in_dtypes,)
82-
f_in_dtypes = dh.fmt_types(in_dtypes)
80+
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype]
81+
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
8382
f_out_dtype = dh.dtype_to_name[out_dtype]
8483
if expected is None:
8584
expected = dh.result_type(*in_dtypes)

array_api_tests/test_creation_functions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_arange(dtype, data):
152152
else:
153153
ph.assert_default_float("arange", out.dtype)
154154
else:
155-
ph.assert_dtype("arange", (out.dtype,), dtype)
155+
ph.assert_kw_dtype("arange", dtype, out.dtype)
156156
f_sig = ", ".join(str(n) for n in args)
157157
if len(kwargs) > 0:
158158
f_sig += f", {ph.fmt_kw(kwargs)}"
@@ -302,7 +302,7 @@ def test_empty(shape, kw):
302302
def test_empty_like(x, kw):
303303
out = xp.empty_like(x, **kw)
304304
if kw.get("dtype", None) is None:
305-
ph.assert_dtype("empty_like", (x.dtype,), out.dtype)
305+
ph.assert_dtype("empty_like", x.dtype, out.dtype)
306306
else:
307307
ph.assert_kw_dtype("empty_like", kw["dtype"], out.dtype)
308308
ph.assert_shape("empty_like", out.shape, x.shape)
@@ -399,7 +399,7 @@ def test_full_like(x, fill_value, kw):
399399
out = xp.full_like(x, fill_value, **kw)
400400
dtype = kw.get("dtype", None) or x.dtype
401401
if kw.get("dtype", None) is None:
402-
ph.assert_dtype("full_like", (x.dtype,), out.dtype)
402+
ph.assert_dtype("full_like", x.dtype, out.dtype)
403403
else:
404404
ph.assert_kw_dtype("full_like", kw["dtype"], out.dtype)
405405
ph.assert_shape("full_like", out.shape, x.shape)
@@ -459,7 +459,7 @@ def test_linspace(num, dtype, endpoint, data):
459459
if dtype is None:
460460
ph.assert_default_float("linspace", out.dtype)
461461
else:
462-
ph.assert_dtype("linspace", (out.dtype,), dtype)
462+
ph.assert_kw_dtype("linspace", dtype, out.dtype)
463463
ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
464464
f_func = f"[linspace({start}, {stop}, {num})]"
465465
if num > 0:
@@ -529,7 +529,7 @@ def test_ones(shape, kw):
529529
def test_ones_like(x, kw):
530530
out = xp.ones_like(x, **kw)
531531
if kw.get("dtype", None) is None:
532-
ph.assert_dtype("ones_like", (x.dtype,), out.dtype)
532+
ph.assert_dtype("ones_like", x.dtype, out.dtype)
533533
else:
534534
ph.assert_kw_dtype("ones_like", kw["dtype"], out.dtype)
535535
ph.assert_shape("ones_like", out.shape, x.shape)
@@ -565,7 +565,7 @@ def test_zeros(shape, kw):
565565
def test_zeros_like(x, kw):
566566
out = xp.zeros_like(x, **kw)
567567
if kw.get("dtype", None) is None:
568-
ph.assert_dtype("zeros_like", (x.dtype,), out.dtype)
568+
ph.assert_dtype("zeros_like", x.dtype, out.dtype)
569569
else:
570570
ph.assert_kw_dtype("zeros_like", kw["dtype"], out.dtype)
571571
ph.assert_shape("zeros_like", out.shape, x.shape)

array_api_tests/test_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def test_matmul(x1, x2):
297297
else:
298298
res = _array_module.matmul(x1, x2)
299299

300-
ph.assert_dtype("matmul", (x1.dtype, x2.dtype), res.dtype)
300+
ph.assert_dtype("matmul", [x1.dtype, x2.dtype], res.dtype)
301301

302302
if len(x1.shape) == len(x2.shape) == 1:
303303
assert res.shape == ()

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def assert_binary_param_dtype(
273273
if ctx.right_is_scalar:
274274
in_dtypes = left.dtype
275275
else:
276-
in_dtypes = (left.dtype, right.dtype) # type: ignore
276+
in_dtypes = [left.dtype, right.dtype] # type: ignore
277277
ph.assert_dtype(
278278
ctx.func_name, in_dtypes, res.dtype, expected, repr_name=f"{ctx.res_name}.dtype"
279279
)
@@ -443,7 +443,7 @@ def test_atan(x):
443443
@given(*hh.two_mutual_arrays(dh.float_dtypes))
444444
def test_atan2(x1, x2):
445445
out = xp.atan2(x1, x2)
446-
ph.assert_dtype("atan2", (x1.dtype, x2.dtype), out.dtype)
446+
ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype)
447447
ph.assert_result_shape("atan2", (x1.shape, x2.shape), out.shape)
448448
INFINITY1 = ah.infinity(x1.shape, x1.dtype)
449449
INFINITY2 = ah.infinity(x2.shape, x2.dtype)
@@ -1294,7 +1294,7 @@ def test_log10(x):
12941294
@given(*hh.two_mutual_arrays(dh.float_dtypes))
12951295
def test_logaddexp(x1, x2):
12961296
out = xp.logaddexp(x1, x2)
1297-
ph.assert_dtype("logaddexp", (x1.dtype, x2.dtype), out.dtype)
1297+
ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype)
12981298
# The spec doesn't require any behavior for this function. We could test
12991299
# that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
13001300
# don't have tests for this sort of thing for any functions yet.
@@ -1303,7 +1303,7 @@ def test_logaddexp(x1, x2):
13031303
@given(*hh.two_mutual_arrays([xp.bool]))
13041304
def test_logical_and(x1, x2):
13051305
out = ah.logical_and(x1, x2)
1306-
ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype)
1306+
ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype)
13071307
ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape)
13081308
binary_assert_against_refimpl(
13091309
"logical_and",
@@ -1329,7 +1329,7 @@ def test_logical_not(x):
13291329
@given(*hh.two_mutual_arrays([xp.bool]))
13301330
def test_logical_or(x1, x2):
13311331
out = ah.logical_or(x1, x2)
1332-
ph.assert_dtype("logical_or", (x1.dtype, x2.dtype), out.dtype)
1332+
ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype)
13331333
ph.assert_result_shape("logical_or", (x1.shape, x2.shape), out.shape)
13341334
binary_assert_against_refimpl(
13351335
"logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}"
@@ -1339,7 +1339,7 @@ def test_logical_or(x1, x2):
13391339
@given(*hh.two_mutual_arrays([xp.bool]))
13401340
def test_logical_xor(x1, x2):
13411341
out = xp.logical_xor(x1, x2)
1342-
ph.assert_dtype("logical_xor", (x1.dtype, x2.dtype), out.dtype)
1342+
ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype)
13431343
ph.assert_result_shape("logical_xor", (x1.shape, x2.shape), out.shape)
13441344
binary_assert_against_refimpl(
13451345
"logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}"

array_api_tests/test_type_promotion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
271271
out = eval(expr, {"x": x, "s": s})
272272
except OverflowError:
273273
reject()
274-
ph.assert_dtype(op, (in_dtype, in_stype), out.dtype, out_dtype)
274+
ph.assert_dtype(op, [in_dtype, in_stype], out.dtype, out_dtype)
275275

276276

277277
inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = []
@@ -307,7 +307,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
307307
reject()
308308
x = locals_["x"]
309309
assert x.dtype == dtype, f"{x.dtype=!s}, but should be {dtype}"
310-
ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, repr_name="x.dtype")
310+
ph.assert_dtype(op, [dtype, in_stype], x.dtype, dtype, repr_name="x.dtype")
311311

312312

313313
if __name__ == "__main__":

0 commit comments

Comments
 (0)