
RFC: add support for a tuple of axes in expand_dims #760

Open
@izaid


Hello all! I raised this issue on array-api-compat earlier (data-apis/array-api-compat#105), but I think it might be more properly directed here.

In the array API, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html), not a tuple of axes. This differs from NumPy, CuPy, and JAX, which all accept a tuple of axes. PyTorch, however, supports only a single axis. I don't know why the array API was limited to a single axis, but the consequence is that many existing expand_dims calls stop working once a library adopts the array API.
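To make the difference concrete, here is a small sketch using NumPy only. It shows a tuple-of-axes call next to the chain of single-axis calls that the array API currently requires. The array shape is an arbitrary choice for illustration.

```python
import numpy as np

x = np.zeros((3, 4))

# NumPy accepts a tuple: the entries are positions in the *output* shape.
multi = np.expand_dims(x, axis=(0, 2))

# With a single-axis expand_dims, the same result needs one call per axis,
# applied in increasing order so that later positions stay valid.
chained = np.expand_dims(np.expand_dims(x, axis=0), axis=2)

print(multi.shape)    # (1, 3, 1, 4)
print(chained.shape)  # (1, 3, 1, 4)
```

The chained form is order-sensitive, which is part of why a tuple argument is more convenient.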

In practice, expand_dims is just a light wrapper around reshape, see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. Still, it's not great to force every library to write its own version of expand_dims. Is the array API willing to update expand_dims to support a tuple of axes? If not, and expand_dims will only ever support a single axis, then everyone who relies on the tuple form effectively has to copy and paste the NumPy implementation.
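For reference, the kind of helper each library would end up carrying looks roughly like the sketch below. It is not the NumPy source, just a minimal reshape-based version under a few assumptions: the name `expand_dims_tuple` is hypothetical, NumPy stands in for the array namespace, and the error handling is simplified.

```python
import numpy as np

def expand_dims_tuple(x, axis):
    """Hypothetical helper: insert size-1 dimensions at each position in `axis`."""
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = x.ndim + len(axis)

    # Normalize negative axes against the rank of the *output* array.
    normalized = []
    for a in axis:
        if not -out_ndim <= a < out_ndim:
            raise ValueError(f"axis {a} is out of bounds for rank {out_ndim}")
        normalized.append(a % out_ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")

    # New axes get size 1; the remaining slots take the original sizes in order.
    sizes = iter(x.shape)
    shape = tuple(1 if i in normalized else next(sizes) for i in range(out_ndim))
    return np.reshape(x, shape)

x = np.zeros((3, 4))
print(expand_dims_tuple(x, (0, 2)).shape)   # (1, 3, 1, 4)
print(expand_dims_tuple(x, (0, -1)).shape)  # (1, 3, 4, 1)
print(expand_dims_tuple(x, 1).shape)        # (3, 1, 4)
```

Only `reshape` is needed, so the same logic should carry over to any array-API-compliant namespace, but that portability is an assumption here rather than something tested.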

@lucascolley pointed out to me that when expand_dims was added to the array API, only NumPy supported a tuple of axes. See #42. That was 4 years ago, and the situation has since changed, as described above.

Labels: Needs Discussion, RFC, topic: Manipulation
