Skip to content

Commit 58aad80

Browse files
authored
Avoid dimshuffle if expand_dims has empty axis (#724)
1 parent 86bc1d2 commit 58aad80

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

pytensor/tensor/basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4247,6 +4247,17 @@ def expand_dims(
42474247
42484248
Insert a new axis that will appear at the `axis` position in the expanded
42494249
array shape.
4250+
4251+
Parameters
4252+
----------
4253+
a :
4254+
The input array.
4255+
axis :
4256+
Position in the expanded axes where the new axis is placed.
4257+
If `axis` is empty, `a` will be returned immediately.
4258+
Returns
4259+
-------
4260+
`a` with a new axis at the `axis` position.
42504261
"""
42514262
a = as_tensor(a)
42524263

@@ -4256,6 +4267,9 @@ def expand_dims(
42564267
out_ndim = len(axis) + a.ndim
42574268
axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
42584269

4270+
if not axis:
4271+
return a
4272+
42594273
dim_it = iter(range(a.ndim))
42604274
pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)]
42614275

0 commit comments

Comments
 (0)