@@ -66,14 +66,7 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -
66
66
if axes is None :
67
67
s_strat = st .none () | s_strat
68
68
s = data .draw (s_strat , label = "s" )
69
- if size_gt_1 :
70
- _s = x .shape if s is None else s
71
- for i in range (x .ndim ):
72
- if i in _axes :
73
- side = _s [_axes .index (i )]
74
- else :
75
- side = x .shape [i ]
76
- assume (side > 1 )
69
+
77
70
norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
78
71
kwargs = data .draw (
79
72
hh .specified_kwargs (
@@ -86,14 +79,14 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -
86
79
return s , axes , norm , kwargs
87
80
88
81
89
- def assert_fft_dtype (func_name : str , * , in_dtype : DataType , out_dtype : DataType ):
82
+ def assert_float_to_complex_dtype (
83
+ func_name : str , * , in_dtype : DataType , out_dtype : DataType
84
+ ):
90
85
if in_dtype == xp .float32 :
91
86
expected = xp .complex64
92
- elif in_dtype == xp .float64 :
93
- expected = xp .complex128
94
87
else :
95
- assert dh . is_float_dtype ( in_dtype ) # sanity check
96
- expected = in_dtype
88
+ assert in_dtype == xp . float64 # sanity check
89
+ expected = xp . complex128
97
90
ph .assert_dtype (
98
91
func_name , in_dtype = in_dtype , out_dtype = out_dtype , expected = expected
99
92
)
@@ -106,14 +99,10 @@ def assert_n_axis_shape(
106
99
n : Optional [int ],
107
100
axis : int ,
108
101
out : Array ,
109
- size_gt_1 : bool = False ,
110
102
):
111
103
_axis = len (x .shape ) - 1 if axis == - 1 else axis
112
104
if n is None :
113
- if size_gt_1 :
114
- axis_side = 2 * (x .shape [_axis ] - 1 )
115
- else :
116
- axis_side = x .shape [_axis ]
105
+ axis_side = x .shape [_axis ]
117
106
else :
118
107
axis_side = n
119
108
expected = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
@@ -127,7 +116,6 @@ def assert_s_axes_shape(
127
116
s : Optional [List [int ]],
128
117
axes : Optional [List [int ]],
129
118
out : Array ,
130
- size_gt_1 : bool = False ,
131
119
):
132
120
_axes = sh .normalise_axis (axes , x .ndim )
133
121
_s = x .shape if s is None else s
@@ -138,88 +126,78 @@ def assert_s_axes_shape(
138
126
else :
139
127
side = x .shape [i ]
140
128
expected .append (side )
141
- if size_gt_1 :
142
- last_axis = _axes [- 1 ]
143
- expected [last_axis ] = 2 * (expected [last_axis ] - 1 )
144
- assume (expected [last_axis ] > 0 ) # TODO: generate valid examples
145
129
ph .assert_shape (func_name , out_shape = out .shape , expected = tuple (expected ))
146
130
147
131
148
- @given (
149
- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
150
- data = st .data (),
151
- )
132
+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
152
133
def test_fft (x , data ):
153
134
n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
154
135
155
136
out = xp .fft .fft (x , ** kwargs )
156
137
157
- assert_fft_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
138
+ ph . assert_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
158
139
assert_n_axis_shape ("fft" , x = x , n = n , axis = axis , out = out )
159
140
160
141
161
- @given (
162
- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
163
- data = st .data (),
164
- )
142
+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
165
143
def test_ifft (x , data ):
166
144
n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
167
145
168
146
out = xp .fft .ifft (x , ** kwargs )
169
147
170
- assert_fft_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
148
+ ph . assert_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
171
149
assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
172
150
173
151
174
- @given (
175
- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
176
- data = st .data (),
177
- )
152
+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
178
153
def test_fftn (x , data ):
179
154
s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
180
155
181
156
out = xp .fft .fftn (x , ** kwargs )
182
157
183
- assert_fft_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
158
+ ph . assert_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
184
159
assert_s_axes_shape ("fftn" , x = x , s = s , axes = axes , out = out )
185
160
186
161
187
- @given (
188
- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
189
- data = st .data (),
190
- )
162
+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
191
163
def test_ifftn (x , data ):
192
164
s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
193
165
194
166
out = xp .fft .ifftn (x , ** kwargs )
195
167
196
- assert_fft_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
168
+ ph . assert_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
197
169
assert_s_axes_shape ("ifftn" , x = x , s = s , axes = axes , out = out )
198
170
199
171
200
- @given (
201
- x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
202
- data = st .data (),
203
- )
172
+ @given (x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ), data = st .data ())
204
173
def test_rfft (x , data ):
205
174
n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
206
175
207
176
out = xp .fft .rfft (x , ** kwargs )
208
177
209
- assert_fft_dtype ("rfft" , in_dtype = x .dtype , out_dtype = out .dtype )
210
- assert_n_axis_shape ("rfft" , x = x , n = n , axis = axis , out = out )
178
+ assert_float_to_complex_dtype ("rfft" , in_dtype = x .dtype , out_dtype = out .dtype )
179
+
180
+ _axis = x .ndim - 1 if axis == - 1 else axis
181
+ if n is None :
182
+ axis_side = x .shape [_axis ] // 2 + 1
183
+ else :
184
+ axis_side = n // 2 + 1
185
+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
186
+ ph .assert_shape ("rfft" , out_shape = out .shape , expected = expected_shape )
211
187
212
188
213
- @given (
214
- x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
215
- data = st .data (),
216
- )
189
+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
217
190
def test_irfft (x , data ):
218
191
n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
219
192
220
193
out = xp .fft .irfft (x , ** kwargs )
221
194
222
- assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
195
+ ph .assert_dtype (
196
+ "irfft" ,
197
+ in_dtype = x .dtype ,
198
+ out_dtype = out .dtype ,
199
+ expected = dh .dtype_components [x .dtype ],
200
+ )
223
201
224
202
_axis = x .ndim - 1 if axis == - 1 else axis
225
203
if n is None :
@@ -230,17 +208,25 @@ def test_irfft(x, data):
230
208
ph .assert_shape ("irfft" , out_shape = out .shape , expected = expected_shape )
231
209
232
210
233
- @given (
234
- x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
235
- data = st .data (),
236
- )
211
+ @given (x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ), data = st .data ())
237
212
def test_rfftn (x , data ):
238
213
s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
239
214
240
215
out = xp .fft .rfftn (x , ** kwargs )
241
216
242
- assert_fft_dtype ("rfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
243
- assert_s_axes_shape ("rfftn" , x = x , s = s , axes = axes , out = out )
217
+ assert_float_to_complex_dtype ("rfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
218
+
219
+ _axes = sh .normalise_axis (axes , x .ndim )
220
+ _s = x .shape if s is None else s
221
+ expected = []
222
+ for i in range (x .ndim ):
223
+ if i in _axes :
224
+ side = _s [_axes .index (i )]
225
+ else :
226
+ side = x .shape [i ]
227
+ expected .append (side )
228
+ expected [_axes [- 1 ]] = _s [- 1 ] // 2 + 1
229
+ ph .assert_shape ("rfftn" , out_shape = out .shape , expected = tuple (expected ))
244
230
245
231
246
232
@given (
@@ -250,24 +236,44 @@ def test_rfftn(x, data):
250
236
data = st .data (),
251
237
)
252
238
def test_irfftn (x , data ):
253
- s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data , size_gt_1 = True )
239
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
254
240
255
241
out = xp .fft .irfftn (x , ** kwargs )
256
242
257
- assert_fft_dtype ("irfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
258
- assert_s_axes_shape ("rfftn" , x = x , s = s , axes = axes , out = out , size_gt_1 = True )
259
-
243
+ ph .assert_dtype (
244
+ "irfftn" ,
245
+ in_dtype = x .dtype ,
246
+ out_dtype = out .dtype ,
247
+ expected = dh .dtype_components [x .dtype ],
248
+ )
260
249
261
- @given (
262
- x = hh .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
263
- data = st .data (),
264
- )
250
+ # TODO: assert shape correctly
251
+ # _axes = sh.normalise_axis(axes, x.ndim)
252
+ # _s = x.shape if s is None else s
253
+ # expected = []
254
+ # for i in range(x.ndim):
255
+ # if i in _axes:
256
+ # side = _s[_axes.index(i)]
257
+ # else:
258
+ # side = x.shape[i]
259
+ # expected.append(side)
260
+ # last_axis = max(_axes)
261
+ # expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
262
+ # ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
263
+
264
+
265
+ @given (x = hh .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ), data = st .data ())
265
266
def test_hfft (x , data ):
266
267
n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
267
268
268
269
out = xp .fft .hfft (x , ** kwargs )
269
270
270
- assert_fft_dtype ("hfft" , in_dtype = x .dtype , out_dtype = out .dtype )
271
+ ph .assert_dtype (
272
+ "hfft" ,
273
+ in_dtype = x .dtype ,
274
+ out_dtype = out .dtype ,
275
+ expected = dh .dtype_components [x .dtype ],
276
+ )
271
277
272
278
_axis = x .ndim - 1 if axis == - 1 else axis
273
279
if n is None :
@@ -278,20 +284,24 @@ def test_hfft(x, data):
278
284
ph .assert_shape ("hfft" , out_shape = out .shape , expected = expected_shape )
279
285
280
286
281
- @given (
282
- x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
283
- data = st .data (),
284
- )
287
+ @given (x = hh .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ), data = st .data ())
285
288
def test_ihfft (x , data ):
286
289
n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
287
290
288
291
out = xp .fft .ihfft (x , ** kwargs )
289
292
290
- assert_fft_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
291
- assert_n_axis_shape ("ihfft" , x = x , n = n , axis = axis , out = out , size_gt_1 = True )
293
+ assert_float_to_complex_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
294
+
295
+ _axis = x .ndim - 1 if axis == - 1 else axis
296
+ if n is None :
297
+ axis_side = x .shape [_axis ] // 2 + 1
298
+ else :
299
+ axis_side = n // 2 + 1
300
+ expected_shape = x .shape [:_axis ] + (axis_side ,) + x .shape [_axis + 1 :]
301
+ ph .assert_shape ("ihfft" , out_shape = out .shape , expected = expected_shape )
292
302
293
303
294
- @given ( n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
304
+ @given (n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
295
305
def test_fftfreq (n , kw ):
296
306
out = xp .fft .fftfreq (n , ** kw )
297
307
ph .assert_shape ("fftfreq" , out_shape = out .shape , expected = (n ,), kw = {"n" : n })
@@ -300,15 +310,18 @@ def test_fftfreq(n, kw):
300
310
@given (n = st .integers (1 , 100 ), kw = hh .kwargs (d = st .floats (0.1 , 5 )))
301
311
def test_rfftfreq (n , kw ):
302
312
out = xp .fft .rfftfreq (n , ** kw )
303
- ph .assert_shape ("rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n })
313
+ ph .assert_shape (
314
+ "rfftfreq" , out_shape = out .shape , expected = (n // 2 + 1 ,), kw = {"n" : n }
315
+ )
304
316
305
317
306
318
@pytest .mark .parametrize ("func_name" , ["fftshift" , "ifftshift" ])
307
319
@given (x = hh .arrays (xps .floating_dtypes (), fft_shapes_strat ), data = st .data ())
308
320
def test_shift_func (func_name , x , data ):
309
321
func = getattr (xp .fft , func_name )
310
322
axes = data .draw (
311
- st .none () | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
323
+ st .none ()
324
+ | st .lists (st .sampled_from (list (range (x .ndim ))), min_size = 1 , unique = True ),
312
325
label = "axes" ,
313
326
)
314
327
out = func (x , axes = axes )
0 commit comments