Skip to content

Commit d68f53f

Browse files
committed
Add more precise output type info to RFFT Op
1 parent e9f58c9 commit d68f53f

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pytensor/tensor/fft.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ class RFFTOp(Op):
1414

1515
def output_type(self, inp):
1616
# add extra dim for real/imag
17-
return TensorType(inp.dtype, shape=(None,) * (inp.type.ndim + 1))
17+
return TensorType(inp.dtype, shape=((None,) * inp.type.ndim) + (2,))
1818

1919
def make_node(self, a, s=None):
2020
a = as_tensor_variable(a)
2121
if a.ndim < 2:
2222
raise TypeError(
23-
f"{self.__class__.__name__}: input must have dimension > 2, with first dimension batches"
23+
f"{self.__class__.__name__}: input must have dimension >= 2, with first dimension batches"
2424
)
2525

2626
if s is None:
@@ -39,9 +39,10 @@ def perform(self, node, inputs, output_storage):
3939
a = inputs[0]
4040
s = inputs[1]
4141

42+
# FIXME: This call is deprecated in numpy 2.0
43+
# axis must be provided when s is not None
4244
A = np.fft.rfftn(a, s=tuple(s))
43-
# Format output with two extra dimensions for real and imaginary
44-
# parts.
45+
# Format output with two extra dimensions for real and imaginary parts.
4546
out = np.zeros((*A.shape, 2), dtype=a.dtype)
4647
out[..., 0], out[..., 1] = np.real(A), np.imag(A)
4748
output_storage[0][0] = out

0 commit comments

Comments
 (0)