|
25 | 25 | mutually_promotable_dtypes, one_d_shapes,
|
26 | 26 | two_mutually_broadcastable_shapes,
|
27 | 27 | SQRT_MAX_ARRAY_SIZE, finite_matrices)
|
28 |
| -from .pytest_helpers import raises |
29 | 28 | from . import dtype_helpers as dh
|
| 29 | +from . import pytest_helpers as ph |
30 | 30 |
|
31 | 31 | from .test_broadcasting import broadcast_shapes
|
32 | 32 |
|
@@ -265,13 +265,19 @@ def test_matmul(x1, x2):
|
265 | 265 | or len(x1.shape) >= 2 and len(x2.shape) >= 2 and x1.shape[-1] != x2.shape[-2]):
|
266 | 266 | # The spec doesn't specify what kind of exception is used here. Most
|
267 | 267 | # libraries will use a custom exception class.
|
268 |
| - raises(Exception, lambda: _array_module.matmul(x1, x2), |
| 268 | + ph.raises(Exception, lambda: _array_module.matmul(x1, x2), |
269 | 269 | "matmul did not raise an exception for invalid shapes")
|
270 | 270 | return
|
271 | 271 | else:
|
272 | 272 | res = _array_module.matmul(x1, x2)
|
273 | 273 |
|
274 |
| - assert res.dtype == dh.promotion_table[x1.dtype, x2.dtype], "matmul() did not return the correct dtype" |
| 274 | + ph.assert_dtype( |
| 275 | + "matmul", |
| 276 | + (x1.dtype, x2.dtype), |
| 277 | + "out.dtype", |
| 278 | + res.dtype, |
| 279 | + dh.promotion_table[x1.dtype, x2.dtype], |
| 280 | + ) |
275 | 281 |
|
276 | 282 | if len(x1.shape) == len(x2.shape) == 1:
|
277 | 283 | assert res.shape == ()
|
|
0 commit comments