@@ -303,15 +303,14 @@ def test_rfftfreq(n, kw):
303
303
ph .assert_shape ("rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n })
304
304
305
305
306
- @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ))
307
- def test_fftshift (x ):
308
- out = xp .fft .fftshift (x )
309
- ph .assert_dtype ("fftshift" , in_dtype = x .dtype , out_dtype = out .dtype )
310
- ph .assert_shape ("fftshift" , out_shape = out .shape , expected = x .shape )
311
-
312
-
313
- @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ))
314
- def test_ifftshift (x ):
315
- out = xp .fft .ifftshift (x )
316
- ph .assert_dtype ("ifftshift" , in_dtype = x .dtype , out_dtype = out .dtype )
317
- ph .assert_shape ("ifftshift" , out_shape = out .shape , expected = x .shape )
306
+ @pytest .mark .parametrize ("func_name" , ["fftshift" , "ifftshift" ])
307
+ @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ), data = st .data ())
308
+ def test_shift_func (func_name , x , data ):
309
+ func = getattr (xp .fft , func_name )
310
+ axes = data .draw (
311
+ st .none () | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
312
+ label = "axes" ,
313
+ )
314
+ out = func (x , axes = axes )
315
+ ph .assert_dtype (func_name , in_dtype = x .dtype , out_dtype = out .dtype )
316
+ ph .assert_shape (func_name , out_shape = out .shape , expected = x .shape )
0 commit comments