Skip to content

Commit 6fa2f9c

Browse files
committed
Use ph.assert_dtype in test_matmul (proof of concept)
1 parent 051d2b1 commit 6fa2f9c

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

array_api_tests/test_linalg.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
mutually_promotable_dtypes, one_d_shapes,
2626
two_mutually_broadcastable_shapes,
2727
SQRT_MAX_ARRAY_SIZE, finite_matrices)
28-
from .pytest_helpers import raises
2928
from . import dtype_helpers as dh
29+
from . import pytest_helpers as ph
3030

3131
from .test_broadcasting import broadcast_shapes
3232

@@ -265,13 +265,19 @@ def test_matmul(x1, x2):
265265
or len(x1.shape) >= 2 and len(x2.shape) >= 2 and x1.shape[-1] != x2.shape[-2]):
266266
# The spec doesn't specify what kind of exception is used here. Most
267267
# libraries will use a custom exception class.
268-
raises(Exception, lambda: _array_module.matmul(x1, x2),
268+
ph.raises(Exception, lambda: _array_module.matmul(x1, x2),
269269
"matmul did not raise an exception for invalid shapes")
270270
return
271271
else:
272272
res = _array_module.matmul(x1, x2)
273273

274-
assert res.dtype == dh.promotion_table[x1.dtype, x2.dtype], "matmul() did not return the correct dtype"
274+
ph.assert_dtype(
275+
"matmul",
276+
(x1.dtype, x2.dtype),
277+
"out.dtype",
278+
res.dtype,
279+
dh.promotion_table[x1.dtype, x2.dtype],
280+
)
275281

276282
if len(x1.shape) == len(x2.shape) == 1:
277283
assert res.shape == ()

0 commit comments

Comments
 (0)