Skip to content

Commit 9decccd

Browse files
committed
Make test_eye more efficient
1 parent f82c7bc commit 9decccd

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -355,13 +355,14 @@ def test_eye(n_rows, n_cols, kw):
355355
_n_cols = n_rows if n_cols is None else n_cols
356356
ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols))
357357
f_func = f"[eye({n_rows=}, {n_cols=})]"
358-
for i in range(n_rows):
359-
for j in range(_n_cols):
358+
k = kw.get("k", 0)
359+
expected = xp.asarray([[1 if j - i == k else 0
360+
for j in range(_n_cols)] for i in range(n_rows)]).reshape(n_rows, _n_cols)
361+
assert out.shape == expected.shape
362+
if xp.any(out != expected):
363+
for i, j in zip(*xp.where(out != expected)):
360364
f_indexed_out = f"out[{i}, {j}]={out[i, j]}"
361-
if j - i == kw.get("k", 0):
362-
assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}"
363-
else:
364-
assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}"
365+
assert out[i, j] == expected[i, j], f"{f_indexed_out}, should be {expected[i, j]} {f_func}"
365366

366367

367368
default_unsafe_dtypes = [xp.uint64]

0 commit comments

Comments
 (0)