Description
This RFC proposes that n-dimensional FFT APIs support a placeholder value in the argument that specifies the shape of the transformed output.
Prior Art
In PyTorch, the shape of the transformed output is given by the s argument, which accepts a placeholder value of -1 to indicate that a dimension should not be padded (the input's existing size along that dimension is used).
>>> import torch
>>> a = torch.tensor([[[0, 0, 0],
...                    [0, 0, 0],
...                    [0, 0, 0]],
...
...                   [[1, 1, 1],
...                    [1, 1, 1],
...                    [1, 1, 1]],
...
...                   [[2, 2, 2],
...                    [2, 2, 2],
...                    [2, 2, 2]]])
>>> y = torch.fft.fftn(a, s=(a.shape[1], 3))
>>> y
tensor([[[ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j]],
        [[ 9.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j]],
        [[18.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j]]])
>>> y = torch.fft.fftn(a, s=(-1, 3))
>>> y
tensor([[[ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j]],
        [[ 9.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j]],
        [[18.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j],
         [ 0.+0.j, 0.+0.j, 0.+0.j]]])
This is especially convenient when one does not want to access the array shape, as may be the case in accelerator libraries.
In contrast, NumPy does not support a placeholder: to leave a dimension unpadded, one must explicitly specify its existing size.
>>> import numpy as np
>>> a = a.numpy()  # NumPy array with the same contents as the tensor above
>>> np.fft.fftn(a, s=(3, 3))
array([[[ 0.+0.j, 0.+0.j, 0.+0.j],
        [ 0.+0.j, 0.+0.j, 0.+0.j],
        [ 0.+0.j, 0.+0.j, 0.+0.j]],
       [[ 9.+0.j, 0.+0.j, 0.+0.j],
        [ 0.+0.j, 0.+0.j, 0.+0.j],
        [ 0.+0.j, 0.+0.j, 0.+0.j]],
       [[18.+0.j, 0.+0.j, 0.+0.j],
        [ 0.+0.j, 0.+0.j, 0.+0.j],
        [ 0.+0.j, 0.+0.j, 0.+0.j]]])
>>> np.fft.fftn(a, s=(-1, 3))
Traceback (most recent call last):
  File "/.../numpy/fft/_pocketfft.py", line 80, in _get_forward_norm
    raise ValueError(f"Invalid number of FFT data points ({n}) specified.")
ValueError: Invalid number of FFT data points (-1) specified.
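Without a placeholder, a caller targeting NumPy has to resolve it before making the call. The following is a minimal sketch of such a workaround in pure Python. resolve_fft_shape is a hypothetical helper, not part of NumPy or PyTorch, and it assumes the same default both libraries document: when axes is not given, s applies to the last len(s) axes.

```python
from typing import Optional, Sequence, Tuple


def resolve_fft_shape(
    shape: Sequence[int],
    s: Sequence[int],
    axes: Optional[Sequence[int]] = None,
) -> Tuple[int, ...]:
    """Hypothetical helper: replace each -1 in s with the input size along its axis."""
    if axes is None:
        if len(s) > len(shape):
            raise ValueError("s has more entries than the input has dimensions")
        # Same default as numpy.fft.fftn / torch.fft.fftn: the last len(s) axes.
        axes = range(len(shape) - len(s), len(shape))
    if len(axes) != len(s):
        raise ValueError("s and axes must have the same length")
    resolved = []
    for n, axis in zip(s, axes):
        if n == -1:
            # Placeholder: keep the existing size (no padding or truncation).
            resolved.append(shape[axis])
        elif n >= 1:
            # Explicit size: the FFT would zero-pad or truncate to n.
            resolved.append(n)
        else:
            raise ValueError(f"Invalid number of FFT data points ({n}) specified.")
    return tuple(resolved)


print(resolve_fft_shape((3, 3, 3), (-1, 3)))               # (3, 3)
print(resolve_fft_shape((2, 4, 6), (-1, 8), axes=(0, 2)))  # (2, 8)
```

For example, np.fft.fftn(a, s=resolve_fft_shape(a.shape, (-1, 3))) would reproduce the PyTorch call above. The helper still has to read a.shape, which is exactly the access a built-in placeholder would let callers avoid.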
Proposal
This RFC proposes that the FFT specification under discussion adopt PyTorch's API and accept -1 as a valid entry in the output shape, indicating that a dimension should be neither padded nor truncated.
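For a single axis, the proposed semantics can be sketched in a few lines of pure Python. The names fit_length and dft are hypothetical, and dft is a naive O(n^2) DFT rather than an FFT; the point is only how the requested length is handled: -1 keeps the input length, a smaller value truncates, and a larger value zero-pads. In the n-dimensional case the same rule would apply independently along each transformed axis.

```python
import cmath
from typing import List, Sequence


def fit_length(x: Sequence[complex], n: int = -1) -> List[complex]:
    """Hypothetical helper: zero-pad or truncate x to length n; -1 keeps len(x)."""
    if n == -1:
        # Placeholder: neither pad nor truncate.
        return list(x)
    if n < 1:
        raise ValueError(f"Invalid number of FFT data points ({n}) specified.")
    # Truncate to n entries, then zero-pad up to n if x was shorter.
    return list(x[:n]) + [0] * max(0, n - len(x))


def dft(x: Sequence[complex], n: int = -1) -> List[complex]:
    """Naive DFT of x after resizing to n (illustration only, not an FFT)."""
    y = fit_length(x, n)
    m = len(y)
    return [
        sum(y[t] * cmath.exp(-2j * cmath.pi * k * t / m) for t in range(m))
        for k in range(m)
    ]


print(fit_length([1, 2, 3], -1))  # [1, 2, 3]
print(fit_length([1, 2, 3], 5))   # [1, 2, 3, 0, 0]
print(fit_length([1, 2, 3], 2))   # [1, 2]
```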
Questions
- Does this RFC present overriding backward compatibility concerns?
- Is there an alternative approach which would be better?