Avoid dimshuffle if expand_dims has empty axis #707

Closed
@ricardoV94

Description

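A small eager optimization we can do to avoid useless dimshuffles.

When `axis` is empty, just return `a` before `dim_it` is built. Otherwise `expand_dims` applies a `dimshuffle` whose pattern is the identity `[0, 1, ..., a.ndim - 1]`, which changes nothing.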

```python
import numpy as np
from pytensor.tensor import as_tensor
from pytensor.tensor.variable import TensorVariable


def expand_dims(
    a: np.ndarray | TensorVariable, axis: tuple[int, ...]
) -> TensorVariable:
    """Expand the shape of an array.

    Insert a new axis that will appear at the `axis` position in the expanded
    array shape.
    """
    a = as_tensor(a)
    if not isinstance(axis, tuple | list):
        axis = (axis,)
    # Total number of dimensions after the new axes are inserted
    out_ndim = len(axis) + a.ndim
    axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
    # "x" marks a new broadcastable axis; existing axes keep their order
    dim_it = iter(range(a.ndim))
    pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)]
    return a.dimshuffle(pattern)
```
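To see why the dimshuffle is useless, the pattern construction can be replayed in plain Python. This is a sketch, not PyTensor code: `dimshuffle_pattern(ndim, axis)` is a hypothetical helper, and a small hand-written bounds/duplicate check stands in for `normalize_axis_tuple`. With an empty `axis`, `out_ndim == a.ndim`, so every position takes the next existing dimension and the pattern is the identity.

```python
def dimshuffle_pattern(ndim, axis):
    """Rebuild the pattern expand_dims passes to dimshuffle (pure-Python sketch)."""
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    out_ndim = len(axis) + ndim
    # Stand-in for normalize_axis_tuple: wrap negative axes, reject bad ones
    new_axes = set()
    for ax in axis:
        if not -out_ndim <= ax < out_ndim:
            raise ValueError(f"axis {ax} out of bounds for {out_ndim} dimensions")
        if ax % out_ndim in new_axes:
            raise ValueError(f"repeated axis {ax}")
        new_axes.add(ax % out_ndim)
    # Same list comprehension as in expand_dims
    dim_it = iter(range(ndim))
    return ["x" if ax in new_axes else next(dim_it) for ax in range(out_ndim)]


print(dimshuffle_pattern(3, ()))       # [0, 1, 2]  (identity: nothing to do)
print(dimshuffle_pattern(3, (0, -1)))  # ['x', 0, 1, 2, 'x']
```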

`squeeze` already does this:

```python
if not axis:
    # Nothing to do
    return _x
```
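A minimal sketch of the proposed early return, written against NumPy arrays so it runs without PyTensor. `expand_dims_sketch` is a hypothetical name; `dimshuffle` is emulated with `reshape`, which only works because an `expand_dims` pattern never reorders the existing axes, and the axis normalization is hand-rolled rather than `normalize_axis_tuple`. The guard itself mirrors the `squeeze` check.

```python
import numpy as np


def expand_dims_sketch(a, axis):
    """NumPy stand-in for expand_dims, with the proposed early return."""
    a = np.asarray(a)
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    if not axis:
        # Nothing to do: return the input itself, as squeeze does
        return a

    out_ndim = len(axis) + a.ndim
    # Hand-rolled replacement for normalize_axis_tuple
    if any(not -out_ndim <= ax < out_ndim for ax in axis):
        raise ValueError(f"axis {axis} out of bounds for {out_ndim} dimensions")
    new_axes = {ax % out_ndim for ax in axis}
    if len(new_axes) != len(axis):
        raise ValueError(f"repeated axis in {axis}")

    dim_it = iter(range(a.ndim))
    pattern = ["x" if ax in new_axes else next(dim_it) for ax in range(out_ndim)]
    # Emulate dimshuffle: each "x" becomes a length-1 axis, existing axes keep order
    return a.reshape([1 if p == "x" else a.shape[p] for p in pattern])


x = np.zeros((2, 3))
print(expand_dims_sketch(x, ()) is x)        # True: the input comes back untouched
print(expand_dims_sketch(x, (0, -1)).shape)  # (1, 2, 3, 1)
```

In the real `expand_dims`, the same guard would presumably sit right after `axis` is wrapped into a tuple and before `dim_it` is created, returning the tensor `a` unchanged.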
