
Commit e88117e

Only require input_ndim and not input_broadcastable in DimShuffle

1 parent d68f53f

24 files changed: +132 −181 lines
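The change in a nutshell: `DimShuffle` is now constructed from the input's rank alone, via keyword-only arguments, instead of from the full static broadcastable pattern. A minimal before/after sketch (illustrative, using the constructor exactly as it appears in this diff):

    from pytensor.tensor.elemwise import DimShuffle

    # Before: the Op was keyed to a full broadcastable pattern
    # DimShuffle((False, False, False), ["x", 2, "x", 0, 1])

    # After: only the number of input dimensions is required
    op = DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])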

pytensor/sparse/sandbox/sp.py (+2 −3)

@@ -19,7 +19,6 @@
 from pytensor.tensor.math import dot
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.shape import reshape
-from pytensor.tensor.subtensor import DimShuffle
 
 
 def register_specialize(lopt, *tags, **kwargs):

@@ -375,7 +374,7 @@ def convolve(
         [images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)]
     )
     tensout = reshape(output, newshp, ndim=3)
-    output = DimShuffle((False,) * tensout.ndim, (0, 2, 1))(tensout)
+    output = tensout.transpose(0, 2, 1)
     if flatten:
         output = pt.flatten(output, 2)

@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
     )
     out2 = reshape(out1, pshape, ndim=3)
 
-    out3 = DimShuffle(out2.broadcastable, (0, 2, 1))(out2)
+    out3 = out2.transpose(0, 2, 1)
 
     return pt.flatten(out3, 2), outshp
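The `tensout.transpose(0, 2, 1)` replacement builds the same axis permutation the explicit `DimShuffle` call did, without constructing the Op by hand. A quick NumPy sketch of that permutation (shapes are illustrative, not from this module):

    import numpy as np

    x = np.empty((2, 3, 4))
    # new_order (0, 2, 1): keep axis 0, swap the last two axes
    assert x.transpose(0, 2, 1).shape == (2, 4, 3)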

pytensor/tensor/basic.py (+2 −2)

@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
         # No-op
         return _x
 
-    ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
+    ret = _x.dimshuffle(axes)
 
     if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
         ret.name = _x.name + ".T"

@@ -3518,7 +3518,7 @@ def grad(self, inp, grads):
                 newdims.append(i)
                 i += 1
 
-        gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
+        gx = gx.dimshuffle(newdims)
         assert gx.type.ndim == x.type.ndim
         assert all(
             s1 == s2

pytensor/tensor/elemwise.py (+66 −95)

@@ -1,5 +1,7 @@
+from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
+from typing import Literal
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
 
     Parameters
     ----------
-    input_broadcastable
-        The expected broadcastable pattern of the input
+    input_ndim
+        The expected number of dimensions of the input
     new_order
         A list representing the relationship between the input's
         dimensions and the output's dimensions. Each element of the
         list can either be an index or 'x'. Indices must be encoded
         as python integers, not pytensor symbolic integers.
-    inplace : bool, optional
-        If True (default), the output will be a view of the input.
+        Missing indexes correspond to dropped dimensions.
 
     Notes
     -----
@@ -77,50 +78,47 @@ class DimShuffle(ExternalCOp):
 
     .. code-block:: python
 
-        DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
+        DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])
 
-    This `Op` will only work on 3d tensors with no broadcastable
-    dimensions. The first dimension will be broadcastable,
+    This `Op` will only work on 3d tensors.
+    The first dimension of the output will be broadcastable,
     then we will have the third dimension of the input tensor as
     the second of the resulting tensor, etc. If the tensor has
     shape (20, 30, 40), the resulting tensor will have dimensions
     (1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)
 
     .. code-block:: python
 
-        DimShuffle((True, False), [1])
+        DimShuffle(input_ndim=2, new_order=[1])
 
-    This `Op` will only work on 2d tensors with the first dimension
-    broadcastable.
-    The second dimension of the input tensor will be the first dimension of
-    the resulting tensor.
-    If the tensor has shape (1, 20), the resulting tensor will have shape
-    (20, ).
+    This `Op` will only work on 2d tensors with the first dimension broadcastable.
+    The second dimension of the input tensor will be the first dimension of the resulting tensor.
+    If the tensor has shape (1, 20), the resulting tensor will have shape (20,).
 
     Examples
     --------
     .. code-block:: python
 
-        DimShuffle((), ["x"])  # make a 0d (scalar) into a 1d vector
-        DimShuffle((False, False), [0, 1])  # identity
-        DimShuffle((False, False), [1, 0])  # inverts the 1st and 2nd dimensions
-        DimShuffle((False,), ["x", 0])  # make a row out of a 1d vector
-        # (N to 1xN)
-        DimShuffle((False,), [0, "x"])  # make a column out of a 1d vector
-        # (N to Nx1)
-        DimShuffle((False, False, False), [2, 0, 1])  # AxBxC to CxAxB
-        DimShuffle((False, False), [0, "x", 1])  # AxB to Ax1xB
-        DimShuffle((False, False), [1, "x", 0])  # AxB to Bx1xA
-
-    The reordering of the dimensions can be done with the numpy.transpose
-    function.
-    Adding, subtracting dimensions can be done with reshape.
+        DimShuffle(input_ndim=0, new_order=["x"])  # make a 0d (scalar) into a 1d vector
+        DimShuffle(input_ndim=2, new_order=[0, 1])  # identity
+        DimShuffle(input_ndim=2, new_order=[1, 0])  # transposition
+        # Make a row out of a 1d vector (N to 1xN)
+        DimShuffle(input_ndim=1, new_order=["x", 0])
+        # Make a column out of a 1d vector (N to Nx1)
+        DimShuffle(input_ndim=1, new_order=[0, "x"])
+        DimShuffle(input_ndim=3, new_order=[2, 0, 1])  # AxBxC to CxAxB
+        DimShuffle(input_ndim=2, new_order=[0, "x", 1])  # AxB to Ax1xB
+        DimShuffle(input_ndim=2, new_order=[1, "x", 0])  # AxB to Bx1xA
 
+    Notes
+    -----
+    The Python implementation of this Op combines numpy.transpose for reordering of the
+    dimensions and numpy.reshape for subtracting and adding broadcastable dimensions.
     """
 
     _f16_ok = True
     check_input = False
-    __props__ = ("input_broadcastable", "new_order", "inplace")
+    __props__ = ("input_ndim", "new_order", "inplace")
     c_func_file = "c_code/dimshuffle.c"
     c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
@@ -133,16 +131,14 @@ def params_type(self):
         inplace=scalar_bool,
     )
 
-    def __init__(self, input_broadcastable, new_order):
+    def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
         super().__init__([self.c_func_file], self.c_func_name)
 
-        self.input_broadcastable = tuple(input_broadcastable)
-        if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
-            raise ValueError(
-                f"input_broadcastable must be boolean, {self.input_broadcastable}"
-            )
-        self.new_order = tuple(new_order)
+        if not isinstance(input_ndim, int):
+            raise TypeError(f"input_ndim must be an integer, got {type(input_ndim)}")
 
+        self.input_ndim = input_ndim
+        self.new_order = tuple(new_order)
         self.inplace = True
 
         for i, j in enumerate(new_order):
@@ -152,10 +148,10 @@ def __init__(self, input_broadcastable, new_order):
                     "DimShuffle indices must be Python ints; got "
                     f"{j} of type {type(j)}."
                 )
-            if j >= len(input_broadcastable):
+            if j >= input_ndim:
                 raise ValueError(
                     f"new_order[{i}] is {j}, but the input only has "
-                    f"{len(input_broadcastable)} axes."
+                    f"{input_ndim} axes."
                 )
             if j in new_order[(i + 1) :]:
                 raise ValueError(
@@ -164,19 +160,7 @@ def __init__(self, input_broadcastable, new_order):
                 )
 
         # List of input dimensions to drop
-        drop = []
-        for i, b in enumerate(input_broadcastable):
-            if i not in new_order:
-                # We want to drop this dimension because it's not a value in
-                # `new_order`
-                if b == 1:
-                    drop.append(i)
-                else:
-                    # We cannot drop non-broadcastable dimensions
-                    raise ValueError(
-                        "Cannot drop a non-broadcastable dimension: "
-                        f"{input_broadcastable}, {new_order}"
-                    )
+        drop = [i for i in range(input_ndim) if i not in new_order]
 
         # This is the list of the original dimensions that we keep
         self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +170,6 @@ def __init__(self, input_broadcastable, new_order):
         self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
         self.drop = drop
 
-        input_ndim = len(input_broadcastable)
         self.is_left_expand_dims = self.augment and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
@@ -204,30 +187,29 @@ def __setstate__(self, state):
         # Let's just build the ExternalCOp.
         super().__init__([self.c_func_file], self.c_func_name)
 
-    def make_node(self, _input):
-        input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
+    def make_node(self, inp):
+        input = as_tensor_variable(inp)
+        if input.type.ndim != self.input_ndim:
+            raise TypeError(
+                "The number of dimensions of the input is incorrect for this op. "
+                f"Expected {self.input_ndim}, got {input.type.ndim}."
+            )
+
+        input_static_shape = input.type.shape
+
+        # Runtime check for invalid drop
+        for d in self.drop:
+            if input_static_shape[d] not in (1, None):
                 raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+                    f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
                 )
-            for expected, b in zip(self.input_broadcastable, ib):
-                if expected and not b:
-                    raise TypeError(
-                        "The broadcastable pattern of the "
-                        f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                    )
-                # else, expected == b or not expected and b
-                # Both case are good.
 
         out_static_shape = []
         for dim_idx in self.new_order:
             if dim_idx == "x":
                 out_static_shape.append(1)
             else:
-                out_static_shape.append(input.type.shape[dim_idx])
+                out_static_shape.append(input_static_shape[dim_idx])
 
         output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
@@ -254,12 +236,14 @@ def perform(self, node, inp, out):
         if not isinstance(res, np.ndarray | np.memmap):
             raise TypeError(res)
 
+        # Put dropped axes at the end
         res = res.transpose(self.transposition)
 
-        shape = list(res.shape[: len(self.shuffle)])
+        # Define the new shape without dropped axes, including the new ones
+        new_shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
-            shape.insert(augm, 1)
-        res = res.reshape(shape)
+            new_shape.insert(augm, 1)
+        res = res.reshape(new_shape)
 
         if not self.inplace:
             res = np.copy(res)
@@ -284,22 +268,15 @@ def R_op(self, inputs, eval_points):
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
-        gz = as_tensor_variable(gz)
         grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
-        # Do not make the DimShuffle inplace as an optimization at the
-        # canonicalization optimization phase will remove the inplace.
-        # The inplace will be reintroduced automatically later in the graph.
-        if inp[0].dtype in discrete_dtypes:
-            return [inp[0].zeros_like(dtype=config.floatX)]
+
+        if x.type.dtype in discrete_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
         else:
-            return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
-                    Elemwise(scalar_identity)(gz)
-                )
-            ]
+            return [gz.dimshuffle(grad_order)]
 
 
 class DimShufflePrinter(Printer):
@@ -409,7 +386,7 @@ def __setstate__(self, d):
         self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)
 
-    def get_output_info(self, dim_shuffle, *inputs):
+    def get_output_info(self, *inputs):
         """Return the outputs dtype and broadcastable pattern and the
         dimshuffled inputs.
@@ -427,12 +404,7 @@ def get_output_info(self, dim_shuffle, *inputs):
             if not difference:
                 args.append(input)
             else:
-                args.append(
-                    dim_shuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(length)),
-                    )(input)
-                )
+                args.append(input.dimshuffle(["x"] * difference + list(range(length))))
         inputs = args
 
         # HERE: all the broadcast dims have the same length now
@@ -489,7 +461,7 @@ def make_node(self, *inputs):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+        out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
         outputs = [
             TensorType(dtype=dtype, shape=shape)()
             for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +606,7 @@ def transform(r):
                 res = pytensor.tensor.basic.constant(
                     np.asarray(r.data), dtype=r.type.dtype
                 )
-                return DimShuffle((), ["x"] * nd)(res)
+                return res.dimshuffle(["x"] * nd)
 
             new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
             if isinstance(new_r, list | tuple):
@@ -1707,13 +1679,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     if not batched_ndims:
        return node.op.make_node(x)
-    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
-    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
-    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
     new_order = list(range(batched_ndims)) + [
         "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
     ]
-    return DimShuffle(input_broadcastable, new_order).make_node(x)
+    return x.dimshuffle(new_order).owner
 
 
 def get_normalized_batch_axes(
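The new Notes section added to the docstring above says the Python implementation combines numpy.transpose and numpy.reshape. A minimal standalone sketch of that scheme (hypothetical helper, simplified from the Op's `perform` method; `shuffle` and `drop` mirror the attributes computed in `__init__`):

    import numpy as np

    def dimshuffle_sketch(x, new_order):
        # Kept axes in their new relative order, dropped axes moved to the end
        shuffle = [d for d in new_order if d != "x"]
        drop = [d for d in range(x.ndim) if d not in new_order]
        x = x.transpose(shuffle + drop)
        # Keep the shuffled axes and insert a length-1 axis at each "x" position;
        # the reshape also removes the trailing dropped (length-1) axes
        new_shape = list(x.shape[: len(shuffle)])
        for i, d in enumerate(new_order):
            if d == "x":
                new_shape.insert(i, 1)
        return x.reshape(new_shape)

    # (20, 30, 40) with new_order ("x", 2, "x", 0, 1) -> (1, 40, 1, 20, 30),
    # matching the docstring example
    assert dimshuffle_sketch(np.empty((20, 30, 40)), ("x", 2, "x", 0, 1)).shape == (1, 40, 1, 20, 30)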

pytensor/tensor/extra_ops.py (+1 −6)

@@ -41,7 +41,7 @@
 )
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.shape import Shape_i, specify_broadcastable
+from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.variable import TensorVariable

@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
         # Nothing could be squeezed
         return _x
 
-    # `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
-    # We add a `specify_broadcastable` instead of raising.
-    non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
-    _x = specify_broadcastable(_x, *non_broadcastable_axis)
-
     return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
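With the constructor no longer demanding a static broadcastable pattern, `squeeze` can defer the length-1 check to `DimShuffle.make_node`, which now accepts a dropped axis whose static length is 1 or unknown (`None`). A hedged usage sketch (variable names illustrative):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(None, 1, None))
    y = pt.squeeze(x, axis=1)  # statically known length-1 axis

    # Unknown static length: previously squeeze inserted specify_broadcastable
    # first; now DimShuffle itself accepts the axis and enforces length 1.
    z = pt.tensor("z", shape=(None, None, None))
    w = pt.squeeze(z, axis=1)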
pytensor/tensor/inplace.py (+2 −2)

@@ -1,6 +1,6 @@
 from pytensor import printing
 from pytensor.printing import pprint
-from pytensor.tensor.elemwise import DimShuffle, scalar_elemwise
+from pytensor.tensor.elemwise import scalar_elemwise
 
 
 @scalar_elemwise

@@ -429,4 +429,4 @@ def hyp2f1_inplace(a, b, c, z):
 def transpose_inplace(x, **kwargs):
     "Perform a transpose on a tensor without copying the underlying storage"
     dims = list(range(x.ndim - 1, -1, -1))
-    return DimShuffle(x.broadcastable, dims)(x)
+    return x.dimshuffle(dims)
