Skip to content

Commit d0d9696

Browse files
authored
Merge pull request #229 from honno/remaining-fft-tests
Remaining FFT tests
2 parents ae0017a + 6e806ec commit d0d9696

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

array_api_tests/test_fft.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,26 @@ def test_ihfft(x, data):
291291
assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True)
292292

293293

294-
# TODO:
295-
# fftfreq
296-
# rfftfreq
297-
# fftshift
298-
# ifftshift
294+
@given( n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
295+
def test_fftfreq(n, kw):
296+
out = xp.fft.fftfreq(n, **kw)
297+
ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n})
298+
299+
300+
@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
301+
def test_rfftfreq(n, kw):
302+
out = xp.fft.rfftfreq(n, **kw)
303+
ph.assert_shape("rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n})
304+
305+
306+
@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])
307+
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat), data=st.data())
308+
def test_shift_func(func_name, x, data):
309+
func = getattr(xp.fft, func_name)
310+
axes = data.draw(
311+
st.none() | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
312+
label="axes",
313+
)
314+
out = func(x, axes=axes)
315+
ph.assert_dtype(func_name, in_dtype=x.dtype, out_dtype=out.dtype)
316+
ph.assert_shape(func_name, out_shape=out.shape, expected=x.shape)

0 commit comments

Comments
 (0)