Hello all! I raised this issue on array-api-compat earlier (data-apis/array-api-compat#105), but I think it might be more properly directed here.
In the array API, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html), not a tuple of axes. This differs from NumPy, CuPy, and JAX, which accept a tuple of axes. PyTorch, however, supports only a single axis. I don't know why the array API chose a single axis rather than a tuple, but in practice it means that code calling expand_dims with a tuple of axes no longer works in many places once it adopts the array API.
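Until the standard changes, a tuple of axes can be emulated with the single-axis expand_dims that the array API does provide. Below is a minimal sketch, assuming xp is an array API namespace (for example, one returned by array_api_compat.array_namespace) and that x.ndim is known. The helper name expand_dims_via_single_axis is hypothetical. It normalizes negative axes against the output rank, as NumPy does, then inserts the new axes in ascending order so that no insertion shifts a later one.

```python
def expand_dims_via_single_axis(x, axis, xp):
    """Insert one or more size-1 axes using only single-axis xp.expand_dims.

    Hypothetical helper; xp is assumed to be an array API namespace.
    """
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    out_ndim = x.ndim + len(axes)
    positions = []
    for ax in axes:
        # Negative axes count from the end of the *output* shape, as in NumPy.
        if not -out_ndim <= ax < out_ndim:
            raise IndexError(f"axis {ax} is out of bounds for output rank {out_ndim}")
        positions.append(ax % out_ndim)
    if len(set(positions)) != len(positions):
        raise ValueError("repeated axis in expand_dims")
    # Insert in ascending order: every dimension left of the current position
    # is already where it belongs in the output, so later positions stay valid.
    for pos in sorted(positions):
        x = xp.expand_dims(x, axis=pos)
    return x
```

The cost is one call per new axis instead of one reshape, which is usually negligible but is part of why a native tuple-of-axes form would be nicer.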
In practice, expand_dims is just a thin wrapper around reshape; see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But it's not great to force users to write their own version of expand_dims in every library. Would the array API be willing to update expand_dims to support a tuple of axes? If not, and expand_dims will only ever support a single axis, then effectively every user of expand_dims will have to copy and paste the NumPy implementation.
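For reference, here is a minimal sketch of that reshape-based approach, modeled loosely on the NumPy implementation linked above but written against the array API's reshape. The names expanded_shape and expand_dims_multi are hypothetical helpers, xp is assumed to be an array API namespace, and x.shape is assumed to contain no unknown (None) dimensions, since reshape needs a concrete shape.

```python
def expanded_shape(shape, axis):
    """Return `shape` with a size-1 dimension inserted at each position in `axis`."""
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    out_ndim = len(shape) + len(axes)
    normalized = set()
    for ax in axes:
        # Like NumPy, interpret negative axes relative to the output rank.
        if not -out_ndim <= ax < out_ndim:
            raise IndexError(f"axis {ax} is out of bounds for output rank {out_ndim}")
        normalized.add(ax % out_ndim)
    if len(normalized) != len(axes):
        raise ValueError("repeated axis in expand_dims")
    dims = iter(shape)
    # Walk the output positions: 1 for a new axis, otherwise the next input dim.
    return tuple(1 if i in normalized else next(dims) for i in range(out_ndim))


def expand_dims_multi(x, axis, xp):
    """Tuple-of-axes expand_dims built only on the array API's reshape."""
    return xp.reshape(x, expanded_shape(tuple(x.shape), axis))
```

Splitting out the pure shape computation keeps the only array operation a single reshape, which is the same design the NumPy version uses.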
@lucascolley pointed out to me that when expand_dims was added to the array API, only NumPy supported a tuple of axes (see #42). That was four years ago, and the situation has since changed, as described above.