@@ -355,13 +355,14 @@ def test_eye(n_rows, n_cols, kw):
355
355
_n_cols = n_rows if n_cols is None else n_cols
356
356
ph .assert_shape ("eye" , out_shape = out .shape , expected = (n_rows , _n_cols ), kw = dict (n_rows = n_rows , n_cols = n_cols ))
357
357
f_func = f"[eye({ n_rows = } , { n_cols = } )]"
358
- for i in range (n_rows ):
359
- for j in range (_n_cols ):
358
+ k = kw .get ("k" , 0 )
359
+ expected = xp .asarray ([[1 if j - i == k else 0
360
+ for j in range (_n_cols )] for i in range (n_rows )]).reshape (n_rows , _n_cols )
361
+ assert out .shape == expected .shape
362
+ if xp .any (out != expected ):
363
+ for i , j in zip (* xp .where (out != expected )):
360
364
f_indexed_out = f"out[{ i } , { j } ]={ out [i , j ]} "
361
- if j - i == kw .get ("k" , 0 ):
362
- assert out [i , j ] == 1 , f"{ f_indexed_out } , should be 1 { f_func } "
363
- else :
364
- assert out [i , j ] == 0 , f"{ f_indexed_out } , should be 0 { f_func } "
365
+ assert out [i , j ] == expected [i , j ], f"{ f_indexed_out } , should be { expected [i , j ]} { f_func } "
365
366
366
367
367
368
default_unsafe_dtypes = [xp .uint64 ]
0 commit comments