+from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
+from typing import Literal

 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):

     Parameters
     ----------
-    input_broadcastable
-        The expected broadcastable pattern of the input
+    input_ndim
+        The expected number of dimensions of the input
     new_order
         A list representing the relationship between the input's
         dimensions and the output's dimensions. Each element of the
         list can either be an index or 'x'. Indices must be encoded
         as python integers, not pytensor symbolic integers.
+        Missing indices correspond to dropped dimensions.

     Notes
     -----
@@ -77,50 +78,47 @@ class DimShuffle(ExternalCOp):

     .. code-block:: python

-        DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
+        DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])

-    This `Op` will only work on 3d tensors with no broadcastable
-    dimensions. The first dimension will be broadcastable,
+    This `Op` will only work on 3d tensors.
+    The first dimension of the output will be broadcastable,
     then we will have the third dimension of the input tensor as
     the second of the resulting tensor, etc. If the tensor has
     shape (20, 30, 40), the resulting tensor will have dimensions
     (1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)

     .. code-block:: python

-        DimShuffle((True, False), [1])
+        DimShuffle(input_ndim=2, new_order=[1])

-    This `Op` will only work on 2d tensors with the first dimension
-    broadcastable.
-    The second dimension of the input tensor will be the first dimension of
-    the resulting tensor.
-    If the tensor has shape (1, 20), the resulting tensor will have shape
-    (20, ).
+    This `Op` will only work on 2d tensors with the first dimension broadcastable.
+    The second dimension of the input tensor will be the first dimension of the resulting tensor.
+    If the tensor has shape (1, 20), the resulting tensor will have shape (20,).

     Examples
     --------
     .. code-block:: python

-        DimShuffle((), ["x"])  # make a 0d (scalar) into a 1d vector
-        DimShuffle((False, False), [0, 1])  # identity
-        DimShuffle((False, False), [1, 0])  # inverts the 1st and 2nd dimensions
-        DimShuffle((False,), ["x", 0])  # make a row out of a 1d vector
-        # (N to 1xN)
-        DimShuffle((False,), [0, "x"])  # make a column out of a 1d vector
-        # (N to Nx1)
-        DimShuffle((False, False, False), [2, 0, 1])  # AxBxC to CxAxB
-        DimShuffle((False, False), [0, "x", 1])  # AxB to Ax1xB
-        DimShuffle((False, False), [1, "x", 0])  # AxB to Bx1xA
-
-    The reordering of the dimensions can be done with the numpy.transpose
-    function.
-    Adding, subtracting dimensions can be done with reshape.
+        DimShuffle(input_ndim=0, new_order=["x"])  # make a 0d (scalar) into a 1d vector
+        DimShuffle(input_ndim=2, new_order=[0, 1])  # identity
+        DimShuffle(input_ndim=2, new_order=[1, 0])  # transposition
+        # Make a row out of a 1d vector (N to 1xN)
+        DimShuffle(input_ndim=1, new_order=["x", 0])
+        # Make a column out of a 1d vector (N to Nx1)
+        DimShuffle(input_ndim=1, new_order=[0, "x"])
+        DimShuffle(input_ndim=3, new_order=[2, 0, 1])  # AxBxC to CxAxB
+        DimShuffle(input_ndim=2, new_order=[0, "x", 1])  # AxB to Ax1xB
+        DimShuffle(input_ndim=2, new_order=[1, "x", 0])  # AxB to Bx1xA

+    Notes
+    -----
+    The Python implementation of this `Op` combines `numpy.transpose` for
+    reordering the dimensions and `numpy.reshape` for dropping and adding
+    broadcastable dimensions.
     """

     _f16_ok = True
     check_input = False
-    __props__ = ("input_broadcastable", "new_order", "inplace")
+    __props__ = ("input_ndim", "new_order", "inplace")
     c_func_file = "c_code/dimshuffle.c"
     c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"

@@ -133,16 +131,14 @@ def params_type(self):
             inplace=scalar_bool,
         )

-    def __init__(self, input_broadcastable, new_order):
+    def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
         super().__init__([self.c_func_file], self.c_func_name)

-        self.input_broadcastable = tuple(input_broadcastable)
-        if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
-            raise ValueError(
-                f"input_broadcastable must be boolean, {self.input_broadcastable}"
-            )
-        self.new_order = tuple(new_order)
+        if not isinstance(input_ndim, int):
+            raise TypeError(f"input_ndim must be an integer, got {type(input_ndim)}")

+        self.input_ndim = input_ndim
+        self.new_order = tuple(new_order)
         self.inplace = True

         for i, j in enumerate(new_order):
@@ -152,10 +148,10 @@ def __init__(self, input_broadcastable, new_order):
                         "DimShuffle indices must be Python ints; got "
                         f"{j} of type {type(j)}."
                     )
-                if j >= len(input_broadcastable):
+                if j >= input_ndim:
                     raise ValueError(
                         f"new_order[{i}] is {j}, but the input only has "
-                        f"{len(input_broadcastable)} axes."
+                        f"{input_ndim} axes."
                     )
                 if j in new_order[(i + 1) :]:
                     raise ValueError(
@@ -164,19 +160,7 @@ def __init__(self, input_broadcastable, new_order):
                     )

         # List of input dimensions to drop
-        drop = []
-        for i, b in enumerate(input_broadcastable):
-            if i not in new_order:
-                # We want to drop this dimension because it's not a value in
-                # `new_order`
-                if b == 1:
-                    drop.append(i)
-                else:
-                    # We cannot drop non-broadcastable dimensions
-                    raise ValueError(
-                        "Cannot drop a non-broadcastable dimension: "
-                        f"{input_broadcastable}, {new_order}"
-                    )
+        drop = [i for i in range(input_ndim) if i not in new_order]

         # This is the list of the original dimensions that we keep
         self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +170,6 @@ def __init__(self, input_broadcastable, new_order):
         self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
         self.drop = drop

-        input_ndim = len(input_broadcastable)
         self.is_left_expand_dims = self.augment and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
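For readers following the refactor, here is a short sketch (values assumed) of how `__init__` now decomposes `new_order` into `drop`, `shuffle`, and `augment` with the same comprehensions as above:

```python
input_ndim = 3
new_order = ["x", 2, 0]  # add one axis, keep axes 2 and 0, drop axis 1

drop = [i for i in range(input_ndim) if i not in new_order]       # [1]
shuffle = [x for x in new_order if x != "x"]                      # [2, 0]
augment = sorted(i for i, x in enumerate(new_order) if x == "x")  # [0]
print(drop, shuffle, augment)
```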
@@ -204,30 +187,29 @@ def __setstate__(self, state):
         # Let's just build the ExternalCOp.
         super().__init__([self.c_func_file], self.c_func_name)

-    def make_node(self, _input):
-        input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
+    def make_node(self, inp):
+        input = as_tensor_variable(inp)
+        if input.type.ndim != self.input_ndim:
+            raise TypeError(
+                "The number of dimensions of the input is incorrect for this op. "
+                f"Expected {self.input_ndim}, got {input.type.ndim}."
+            )
+
+        input_static_shape = input.type.shape
+
+        # Runtime check for invalid drop
+        for d in self.drop:
+            if input_static_shape[d] not in (1, None):
                 raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+                    f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
                 )
-            for expected, b in zip(self.input_broadcastable, ib):
-                if expected and not b:
-                    raise TypeError(
-                        "The broadcastable pattern of the "
-                        f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                    )
-                # else, expected == b or not expected and b
-                # Both case are good.

         out_static_shape = []
         for dim_idx in self.new_order:
             if dim_idx == "x":
                 out_static_shape.append(1)
             else:
-                out_static_shape.append(input.type.shape[dim_idx])
+                out_static_shape.append(input_static_shape[dim_idx])

         output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()

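Since the broadcast pattern is no longer part of the `Op`, drop validation now happens in `make_node` against the input's static shape. A hedged sketch of the behavior this hunk implements (shapes assumed):

```python
import pytensor.tensor as pt

row = pt.tensor("row", shape=(1, 5))
print(row.dimshuffle([1]).type.shape)  # (5,): dim 0 has length 1, can be dropped

mat = pt.tensor("mat", shape=(3, 5))
try:
    mat.dimshuffle([1])  # dim 0 has static length 3, cannot be dropped
except TypeError as err:
    print(err)
```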
@@ -254,12 +236,14 @@ def perform(self, node, inp, out):
         if not isinstance(res, np.ndarray | np.memmap):
             raise TypeError(res)

+        # Put dropped axes at the end
         res = res.transpose(self.transposition)

-        shape = list(res.shape[: len(self.shuffle)])
+        # Define the new shape without the dropped axes and with the new ones
+        new_shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
-            shape.insert(augm, 1)
-        res = res.reshape(shape)
+            new_shape.insert(augm, 1)
+        res = res.reshape(new_shape)

         if not self.inplace:
             res = np.copy(res)
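As the new docstring Notes say, `perform` is just `transpose` plus `reshape`. A pure-NumPy sketch of the path above, with `shuffle`/`augment` values assumed from the docstring example:

```python
import numpy as np

x = np.empty((20, 30, 40))   # new_order = ["x", 2, "x", 0, 1]
shuffle = [2, 0, 1]          # kept input axes, in output order
drop = []                    # transposition = shuffle + drop
augment = [0, 2]             # positions of the inserted length-1 axes

res = x.transpose(shuffle + drop)            # dropped axes go to the end
new_shape = list(res.shape[: len(shuffle)])  # cut off the dropped axes
for augm in augment:
    new_shape.insert(augm, 1)
print(res.reshape(new_shape).shape)          # (1, 40, 1, 20, 30)
```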
@@ -284,22 +268,15 @@ def R_op(self, inputs, eval_points):
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
-        gz = as_tensor_variable(gz)
         grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
-        # Do not make the DimShuffle inplace as an optimization at the
-        # canonicalization optimization phase will remove the inplace.
-        # The inplace will be reintroduced automatically later in the graph.
-        if inp[0].dtype in discrete_dtypes:
-            return [inp[0].zeros_like(dtype=config.floatX)]
+
+        if x.type.dtype in discrete_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
         else:
-            return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
-                    Elemwise(scalar_identity)(gz)
-                )
-            ]
+            return [gz.dimshuffle(grad_order)]


 class DimShufflePrinter(Printer):
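The simplified `grad` relies on the gradient of a `DimShuffle` being the `DimShuffle` that sends each kept output axis back to its input position. A small sketch (values assumed) of the `grad_order` inversion:

```python
new_order = ["x", 2, "x", 0, 1]  # forward pattern, input_ndim == 3
grad_order = ["x"] * 3
for i, v in enumerate(new_order):
    if v != "x":
        grad_order[v] = i
print(grad_order)  # [3, 4, 1]: gz.dimshuffle(grad_order) matches x's shape
```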
@@ -409,7 +386,7 @@ def __setstate__(self, d):
         self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)

-    def get_output_info(self, dim_shuffle, *inputs):
+    def get_output_info(self, *inputs):
         """Return the outputs dtype and broadcastable pattern and the
         dimshuffled inputs.

@@ -427,12 +404,7 @@ def get_output_info(self, dim_shuffle, *inputs):
             if not difference:
                 args.append(input)
             else:
-                args.append(
-                    dim_shuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(length)),
-                    )(input)
-                )
+                args.append(input.dimshuffle(["x"] * difference + list(range(length))))
         inputs = args

         # HERE: all the broadcast dims have the same length now
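The replacement line pads missing leading dimensions with the input's own `dimshuffle` method instead of constructing a `DimShuffle` directly. A sketch (target rank assumed):

```python
import pytensor.tensor as pt

v = pt.vector("v")        # ndim 1, to be broadcast against a 3d input
difference = 3 - v.ndim
padded = v.dimshuffle(["x"] * difference + list(range(v.ndim)))
print(padded.type.shape)  # (1, 1, None)
```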
@@ -489,7 +461,7 @@ def make_node(self, *inputs):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+        out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
         outputs = [
             TensorType(dtype=dtype, shape=shape)()
             for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +606,7 @@ def transform(r):
                 res = pytensor.tensor.basic.constant(
                     np.asarray(r.data), dtype=r.type.dtype
                 )
-                return DimShuffle((), ["x"] * nd)(res)
+                return res.dimshuffle(["x"] * nd)

             new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
             if isinstance(new_r, list | tuple):
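Same idea for lifting scalar constants: `res.dimshuffle(["x"] * nd)` expands a 0d constant to `nd` broadcastable dimensions. A sketch (constant value assumed):

```python
import pytensor.tensor as pt

c = pt.constant(2.5)                       # 0d constant
print(c.dimshuffle(["x"] * 3).type.shape)  # (1, 1, 1)
```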
@@ -1707,13 +1679,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     if not batched_ndims:
         return node.op.make_node(x)
-    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
-    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
-    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
     new_order = list(range(batched_ndims)) + [
         "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
     ]
-    return DimShuffle(input_broadcastable, new_order).make_node(x)
+    return x.dimshuffle(new_order).owner


 def get_normalized_batch_axes(
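The batched `new_order` computation can be checked by hand; this sketch reproduces the first example in the updated comment (values assumed):

```python
batched_ndims = 2
op_new_order = [1, "x", 0]  # ds(input_ndim=2, order=(1, "x", 0))

new_order = list(range(batched_ndims)) + [
    "x" if o == "x" else o + batched_ndims for o in op_new_order
]
print(new_order)  # [0, 1, 3, 'x', 2]
```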