@@ -274,29 +274,45 @@ def test_reshape(x, data):
274
274
275
275
@given (xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ()), st .data ())
276
276
def test_roll (x , data ):
277
- shift = data .draw (
278
- st .integers () | st .lists (st .integers (), max_size = x .ndim ).map (tuple ),
279
- label = "shift" ,
280
- )
281
- axis_strats = [st .none ()]
282
- if x .shape != ():
283
- axis_strats .append (st .integers (- x .ndim , x .ndim - 1 ))
284
- if isinstance (shift , int ):
285
- axis_strats .append (xps .valid_tuple_axes (x .ndim ))
286
- kw = data .draw (hh .kwargs (axis = st .one_of (axis_strats )), label = "kw" )
277
+ shift_strat = st .integers (- hh .MAX_ARRAY_SIZE , hh .MAX_ARRAY_SIZE )
278
+ if x .ndim > 0 :
279
+ shift_strat = shift_strat | st .lists (
280
+ shift_strat , min_size = 1 , max_size = x .ndim
281
+ ).map (tuple )
282
+ shift = data .draw (shift_strat , label = "shift" )
283
+ if isinstance (shift , tuple ):
284
+ axis_strat = xps .valid_tuple_axes (x .ndim ).filter (lambda t : len (t ) == len (shift ))
285
+ kw_strat = axis_strat .map (lambda t : {"axis" : t })
286
+ else :
287
+ axis_strat = st .none ()
288
+ if x .ndim != 0 :
289
+ axis_strat = axis_strat | st .integers (- x .ndim , x .ndim - 1 )
290
+ kw_strat = hh .kwargs (axis = axis_strat )
291
+ kw = data .draw (kw_strat , label = "kw" )
287
292
288
293
out = xp .roll (x , shift , ** kw )
289
294
290
295
ph .assert_dtype ("roll" , x .dtype , out .dtype )
291
296
292
297
ph .assert_result_shape ("roll" , (x .shape ,), out .shape )
293
298
294
- # TODO: test all shift/axis scenarios
295
- if isinstance (shift , int ) and kw . get ( "axis" , None ) is None :
299
+ if kw . get ( "axis" , None ) is None :
300
+ assert isinstance (shift , int ) # sanity check
296
301
indices = list (ah .ndindex (x .shape ))
297
302
shifted_indices = deque (indices )
298
303
shifted_indices .rotate (- shift )
299
304
assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
305
+ else :
306
+ _shift = (shift ,) if isinstance (shift , int ) else shift
307
+ axes = normalise_axis (kw ["axis" ], x .ndim )
308
+ all_indices = list (ah .ndindex (x .shape ))
309
+ for s , a in zip (_shift , axes ):
310
+ side = x .shape [a ]
311
+ for i in range (side ):
312
+ indices = [idx for idx in all_indices if idx [a ] == i ]
313
+ shifted_indices = deque (indices )
314
+ shifted_indices .rotate (- s )
315
+ assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
300
316
301
317
302
318
@given (
0 commit comments