Skip to content

Commit 490264e

Browse files
committed
Cover all shift/axes scenarios in test_roll
1 parent c4f109e commit 490264e

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -274,29 +274,45 @@ def test_reshape(x, data):
274274

275275
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data())
276276
def test_roll(x, data):
277-
shift = data.draw(
278-
st.integers() | st.lists(st.integers(), max_size=x.ndim).map(tuple),
279-
label="shift",
280-
)
281-
axis_strats = [st.none()]
282-
if x.shape != ():
283-
axis_strats.append(st.integers(-x.ndim, x.ndim - 1))
284-
if isinstance(shift, int):
285-
axis_strats.append(xps.valid_tuple_axes(x.ndim))
286-
kw = data.draw(hh.kwargs(axis=st.one_of(axis_strats)), label="kw")
277+
shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE)
278+
if x.ndim > 0:
279+
shift_strat = shift_strat | st.lists(
280+
shift_strat, min_size=1, max_size=x.ndim
281+
).map(tuple)
282+
shift = data.draw(shift_strat, label="shift")
283+
if isinstance(shift, tuple):
284+
axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift))
285+
kw_strat = axis_strat.map(lambda t: {"axis": t})
286+
else:
287+
axis_strat = st.none()
288+
if x.ndim != 0:
289+
axis_strat = axis_strat | st.integers(-x.ndim, x.ndim - 1)
290+
kw_strat = hh.kwargs(axis=axis_strat)
291+
kw = data.draw(kw_strat, label="kw")
287292

288293
out = xp.roll(x, shift, **kw)
289294

290295
ph.assert_dtype("roll", x.dtype, out.dtype)
291296

292297
ph.assert_result_shape("roll", (x.shape,), out.shape)
293298

294-
# TODO: test all shift/axis scenarios
295-
if isinstance(shift, int) and kw.get("axis", None) is None:
299+
if kw.get("axis", None) is None:
300+
assert isinstance(shift, int) # sanity check
296301
indices = list(ah.ndindex(x.shape))
297302
shifted_indices = deque(indices)
298303
shifted_indices.rotate(-shift)
299304
assert_array_ndindex("roll", x, indices, out, shifted_indices)
305+
else:
306+
_shift = (shift,) if isinstance(shift, int) else shift
307+
axes = normalise_axis(kw["axis"], x.ndim)
308+
all_indices = list(ah.ndindex(x.shape))
309+
for s, a in zip(_shift, axes):
310+
side = x.shape[a]
311+
for i in range(side):
312+
indices = [idx for idx in all_indices if idx[a] == i]
313+
shifted_indices = deque(indices)
314+
shifted_indices.rotate(-s)
315+
assert_array_ndindex("roll", x, indices, out, shifted_indices)
300316

301317

302318
@given(

0 commit comments

Comments
 (0)