@@ -1700,21 +1700,22 @@ def do_constant_folding(self, fgraph, node):
1700
1700
return False
1701
1701
1702
1702
for client , idx in clients :
1703
- if isinstance (client .op , Output ):
1703
+ client_op = client .op
1704
+ if isinstance (client_op , Output ):
1704
1705
# If the output is a constant, it will have to be deepcopied
1705
1706
# each time the function is called. So we do not fold.
1706
1707
return False
1707
- # Allow alloc to be lifted out of Elemwise before constant folding it
1708
- elif isinstance (client . op , Elemwise ):
1709
- return None
1708
+ # Op's through which Alloc can be lifted
1709
+ elif isinstance (client_op , Elemwise | DimShuffle | Alloc | Join ):
1710
+ return False
1710
1711
# Same for Blockwise, unless it has no batch_dims
1711
- elif isinstance (client . op , Blockwise ) and client .op .batch_ndim (client ):
1712
- return None
1712
+ elif isinstance (client_op , Blockwise ) and client .op .batch_ndim (client ):
1713
+ return False
1713
1714
elif (
1714
1715
# The following ops work inplace of their input id 0.
1715
1716
idx == 0
1716
1717
and isinstance (
1717
- client . op ,
1718
+ client_op ,
1718
1719
pytensor .tensor .subtensor .IncSubtensor
1719
1720
| pytensor .tensor .subtensor .AdvancedIncSubtensor1
1720
1721
| pytensor .tensor .subtensor .AdvancedIncSubtensor
@@ -2035,10 +2036,15 @@ def transpose(x, axes=None):
2035
2036
_x = as_tensor_variable (x )
2036
2037
2037
2038
if axes is None :
2038
- axes = list (range ((_x .type .ndim - 1 ), - 1 , - 1 ))
2039
+ axes = tuple (range ((_x .type .ndim - 1 ), - 1 , - 1 ))
2040
+
2041
+ if tuple (axes ) == tuple (range (len (axes ))):
2042
+ # No-op
2043
+ return _x
2044
+
2039
2045
ret = DimShuffle (tuple (s == 1 for s in _x .type .shape ), axes )(_x )
2040
2046
2041
- if _x .name and axes == list (range ((_x .type .ndim - 1 ), - 1 , - 1 )):
2047
+ if _x .name and axes == tuple (range ((_x .type .ndim - 1 ), - 1 , - 1 )):
2042
2048
ret .name = _x .name + ".T"
2043
2049
2044
2050
return ret
@@ -3950,6 +3956,10 @@ def moveaxis(
3950
3956
source = normalize_axis_tuple (source , a .ndim , "source" )
3951
3957
destination = normalize_axis_tuple (destination , a .ndim , "destination" )
3952
3958
3959
+ if source == destination :
3960
+ # It's a no-op
3961
+ return a
3962
+
3953
3963
if len (source ) != len (destination ):
3954
3964
raise ValueError (
3955
3965
"`source` and `destination` arguments must have the same number of elements"
@@ -4260,9 +4270,7 @@ def atleast_Nd(
4260
4270
atleast_3d = partial (atleast_Nd , n = 3 )
4261
4271
4262
4272
4263
- def expand_dims (
4264
- a : np .ndarray | TensorVariable , axis : tuple [int , ...]
4265
- ) -> TensorVariable :
4273
+ def expand_dims (a : np .ndarray | TensorVariable , axis : Sequence [int ]) -> TensorVariable :
4266
4274
"""Expand the shape of an array.
4267
4275
4268
4276
Insert a new axis that will appear at the `axis` position in the expanded
@@ -4281,7 +4289,7 @@ def expand_dims(
4281
4289
"""
4282
4290
a = as_tensor (a )
4283
4291
4284
- if not isinstance (axis , tuple | list ):
4292
+ if not isinstance (axis , Sequence ):
4285
4293
axis = (axis ,)
4286
4294
4287
4295
out_ndim = len (axis ) + a .ndim
0 commit comments