@@ -291,8 +291,26 @@ def test_ihfft(x, data):
291
291
assert_n_axis_shape ("ihfft" , x = x , n = n , axis = axis , out = out , size_gt_1 = True )
292
292
293
293
294
- # TODO:
295
- # fftfreq
296
- # rfftfreq
297
- # fftshift
298
- # ifftshift
294
+ @given ( n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
295
+ def test_fftfreq (n , kw ):
296
+ out = xp .fft .fftfreq (n , ** kw )
297
+ ph .assert_shape ("fftfreq" , out_shape = out .shape , expected = (n ,), kw = {"n" : n })
298
+
299
+
300
+ @given (n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
301
+ def test_rfftfreq (n , kw ):
302
+ out = xp .fft .rfftfreq (n , ** kw )
303
+ ph .assert_shape ("rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n })
304
+
305
+
306
+ @pytest .mark .parametrize ("func_name" , ["fftshift" , "ifftshift" ])
307
+ @given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ), data = st .data ())
308
+ def test_shift_func (func_name , x , data ):
309
+ func = getattr (xp .fft , func_name )
310
+ axes = data .draw (
311
+ st .none () | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
312
+ label = "axes" ,
313
+ )
314
+ out = func (x , axes = axes )
315
+ ph .assert_dtype (func_name , in_dtype = x .dtype , out_dtype = out .dtype )
316
+ ph .assert_shape (func_name , out_shape = out .shape , expected = x .shape )
0 commit comments