Skip to content

Commit 6e806ec

Browse files
committed
Smoke axes argument in FFT shift tests
1 parent 674dd0a commit 6e806ec

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

array_api_tests/test_fft.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,14 @@ def test_rfftfreq(n, kw):
303303
ph.assert_shape("rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n})
304304

305305

306-
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat))
307-
def test_fftshift(x):
308-
out = xp.fft.fftshift(x)
309-
ph.assert_dtype("fftshift", in_dtype=x.dtype, out_dtype=out.dtype)
310-
ph.assert_shape("fftshift", out_shape=out.shape, expected=x.shape)
311-
312-
313-
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat))
314-
def test_ifftshift(x):
315-
out = xp.fft.ifftshift(x)
316-
ph.assert_dtype("ifftshift", in_dtype=x.dtype, out_dtype=out.dtype)
317-
ph.assert_shape("ifftshift", out_shape=out.shape, expected=x.shape)
306+
@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])
307+
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat), data=st.data())
308+
def test_shift_func(func_name, x, data):
309+
func = getattr(xp.fft, func_name)
310+
axes = data.draw(
311+
st.none() | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
312+
label="axes",
313+
)
314+
out = func(x, axes=axes)
315+
ph.assert_dtype(func_name, in_dtype=x.dtype, out_dtype=out.dtype)
316+
ph.assert_shape(func_name, out_shape=out.shape, expected=x.shape)

0 commit comments

Comments
 (0)