
RFC: specifying the shape of the transformed output in n-dimensional FFT APIs  #476

Closed
@kgryte

Description


This RFC proposes that n-dimensional FFT APIs support a placeholder value within the shape of the transformed output.

Prior Art

In PyTorch, the shape of the transformed output may contain a placeholder value of -1, which indicates that the corresponding dimension should not be padded (its existing size is used).

>>> import torch
>>> a = torch.tensor([[[0, 0, 0],
...                    [0, 0, 0],
...                    [0, 0, 0]],
...
...                   [[1, 1, 1],
...                    [1, 1, 1],
...                    [1, 1, 1]],
...
...                   [[2, 2, 2],
...                    [2, 2, 2],
...                    [2, 2, 2]]])
>>> y = torch.fft.fftn(a, s=(a.shape[1], 3))
>>> y
tensor([[[ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j]],

        [[ 9.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j]],

        [[18.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j]]])
>>> y = torch.fft.fftn(a, s=(-1, 3))
>>> y
tensor([[[ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j]],

        [[ 9.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j]],

        [[18.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j]]])

This is especially convenient when one would rather not access the array's shape (as may be the case for accelerator libraries).
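Inside a library, the placeholder amounts to a small shape-resolution step that the implementation, rather than the caller, performs before any padding or truncation. The helper below, `resolve_fft_shape`, is hypothetical (it is not a PyTorch or NumPy function), and it uses NumPy's parameter name `axes` where PyTorch uses `dim`. As in the examples above, it assumes that when `axes` is omitted, `s` applies to the trailing `len(s)` dimensions, and it treats anything other than -1 or a positive integer as an error.

```python
from typing import Optional, Sequence, Tuple


def resolve_fft_shape(
    shape: Sequence[int],
    s: Sequence[int],
    axes: Optional[Sequence[int]] = None,
) -> Tuple[int, ...]:
    """Return ``s`` with every -1 replaced by the input's size along that axis."""
    ndim = len(shape)
    if axes is None:
        # Assumption: as in the examples above, ``s`` applies to the trailing axes.
        if len(s) > ndim:
            raise ValueError("s has more entries than the input has dimensions")
        axes = tuple(range(ndim - len(s), ndim))
    if len(axes) != len(s):
        raise ValueError("s and axes must have the same length")
    resolved = []
    for n, axis in zip(s, axes):
        if n == -1:
            # Placeholder: reuse the existing size (no padding or truncation).
            resolved.append(shape[axis])
        elif n >= 1:
            # Explicit size: the input will later be padded or truncated to n.
            resolved.append(n)
        else:
            raise ValueError(f"invalid transform length {n}; expected -1 or a positive integer")
    return tuple(resolved)


# Mirrors the PyTorch calls above for an input of shape (3, 3, 3).
print(resolve_fft_shape((3, 3, 3), (-1, 3)))  # (3, 3)
print(resolve_fft_shape((2, 4, 5), (-1, 8), axes=(0, 2)))  # (2, 8)
```

Because the resolution happens inside the library, the caller never has to query the shape; only the implementation reads it.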

In contrast, NumPy does not support a placeholder: to leave a dimension unpadded, one must specify that dimension's existing size.

>>> import numpy as np
>>> np.fft.fftn(a, s=(3, 3))
array([[[ 0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j]],

       [[ 9.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j]],

       [[18.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j]]])
>>> np.fft.fftn(a, s=(-1, 3))
Traceback (most recent call last):
  File "/.../numpy/fft/_pocketfft.py", line 80, in _get_forward_norm
    raise ValueError(f"Invalid number of FFT data points ({n}) specified.")
ValueError: Invalid number of FFT data points (-1) specified.
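Under these constraints, a caller targeting NumPy could emulate the placeholder by substituting each -1 with the corresponding input size before calling `np.fft.fftn`. The wrapper below, `fftn_with_placeholder`, is a hypothetical sketch and not a NumPy API; it requires `axes` to be passed explicitly so that each entry of `s` is unambiguously paired with an axis.

```python
import numpy as np


def fftn_with_placeholder(a, s, axes):
    """Hypothetical wrapper: np.fft.fftn, treating -1 in ``s`` as "keep this axis's size"."""
    a = np.asarray(a)
    if len(s) != len(axes):
        raise ValueError("s and axes must have the same length")
    # Substitute the existing size wherever the placeholder -1 appears.
    resolved = tuple(a.shape[ax] if n == -1 else n for n, ax in zip(s, axes))
    return np.fft.fftn(a, s=resolved, axes=axes)


# Same values as the example array above: slice k is a 3x3 block filled with k.
a = np.stack([np.full((3, 3), k) for k in range(3)])
y = fftn_with_placeholder(a, s=(-1, 3), axes=(1, 2))
print(np.allclose(y, np.fft.fftn(a, s=(3, 3), axes=(1, 2))))  # True
```

The workaround still reads `a.shape` internally, which is exactly the bookkeeping a built-in placeholder would move out of user code.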

Proposal

This RFC proposes that the FFT specification adopt PyTorch's convention and accept -1 as a valid entry in the output shape, indicating that the corresponding dimension should be neither padded nor truncated.
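At the data level, the proposed semantics can be sketched as a per-axis pad-or-truncate step applied before the transform. The function `pad_or_truncate` below is hypothetical and written with NumPy for illustration; it assumes the NumPy behavior of keeping the leading elements when truncating and appending zeros when padding, and it rejects values other than -1 or a positive integer.

```python
import numpy as np


def pad_or_truncate(x, s, axes):
    """Hypothetical sketch of the proposed semantics for each (s[i], axes[i]) pair.

    -1        -> leave the axis unchanged
    n < size  -> keep the first n elements (truncate)
    n > size  -> append zeros up to length n (pad)
    """
    x = np.asarray(x)
    for n, ax in zip(s, axes):
        if n != -1 and n < 1:
            raise ValueError(f"invalid transform length {n}")
        size = x.shape[ax]
        if n == -1 or n == size:
            continue  # placeholder (or matching size): neither padded nor truncated
        if n < size:
            x = np.take(x, np.arange(n), axis=ax)  # truncate to the first n elements
        else:
            widths = [(0, 0)] * x.ndim
            widths[ax] = (0, n - size)
            x = np.pad(x, widths)  # zero-pad at the end of this axis
    return x


x = np.arange(24.0).reshape(2, 3, 4)
y = pad_or_truncate(x, s=(-1, 6), axes=(1, 2))
print(y.shape)  # (2, 3, 6)
print(np.allclose(np.fft.fftn(y, axes=(1, 2)), np.fft.fftn(x, s=(3, 6), axes=(1, 2))))  # True
```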

Questions

  • Does this RFC present overriding backward compatibility concerns?
  • Is there an alternative approach which would be better?

Metadata

Labels: API extension (adds new functions or objects to the API); topic: FFT (Fast Fourier transforms)
