Skip to content

Commit a0cca15

Browse files
committed
Cover elements in test_permute_dims
1 parent 2b60988 commit a0cca15

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,12 @@ def test_permute_dims(x, axes):
237237
for i, dim in enumerate(axes):
238238
side = x.shape[dim]
239239
shape[i] = side
240-
assert all(isinstance(side, int) for side in shape) # sanity check
241240
shape = tuple(shape)
242241
ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes)
243242

244-
# TODO: test elements
243+
indices = list(ah.ndindex(x.shape))
244+
permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices]
245+
assert_array_ndindex("permute_dims", x, indices, out, permuted_indices)
245246

246247

247248
@st.composite

0 commit comments

Comments
 (0)