Skip to content

Commit 8dc67ea

Browse files
committed
Fix bug in local_dimshuffle_subtensor rewrite
1 parent 4efbd19 commit 8dc67ea

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

pytensor/tensor/rewriting/uncanonicalize.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,15 @@ def local_dimshuffle_alloc(fgraph, node):
174174
def local_dimshuffle_subtensor(fgraph, node):
175175
"""If a subtensor is inside a dimshuffle which only drop
176176
broadcastable dimensions, scrap the dimshuffle and index the
177-
subtensor with 0
177+
subtensor in a way that avoids the degenerate dimension
178178
179179
x[i:j, :, k:l].dimshuffle(0, 2) =>
180180
x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
181181
182+
x[i:j, k:l, :].dimshuffle(0, 2) => x[i:j, k, :]
183+
x[i:j, k:, :].dimshuffle(0, 2) => x[i:j, k, :]
184+
x[i:j, :l, :].dimshuffle(0, 2) => x[i:j, 0, :]
185+
182186
"""
183187
if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
184188
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
@@ -217,24 +221,40 @@ def local_dimshuffle_subtensor(fgraph, node):
217221
new_idx_list = list(input_.owner.op.idx_list)
218222
new_inputs = [input_.owner.inputs[0]]
219223
zero = constant(0)
220-
slice_attr_list = ["start", "stop", "step"]
221224
j = 0
222225
slice_i = -1
223226
subtensor_removed_dims = 0
224227
for i, idx in enumerate(input_.owner.op.idx_list):
225228
if isinstance(idx, slice):
226-
past_j = j
227229
slice_i += 1
228-
for slice_attr in slice_attr_list:
229-
if getattr(idx, slice_attr) is not None:
230-
new_inputs += [input_.owner.inputs[1 + j]]
231-
j += 1
232-
# if past_j == j indicates a slice(None, None, None),
233-
# that's where we want to index with 0 if it is also at
234-
# the same spot of a missing dim
235-
if past_j == j and slice_i in missing_dims:
236-
new_idx_list[i] = zero
237-
new_inputs += [zero]
230+
if slice_i in missing_dims:
231+
# Missing dim is a slice(None), remove by indexing by 0
232+
if idx == slice(None):
233+
new_idx_list[i] = zero
234+
new_inputs += [zero]
235+
# Missing dim is an ordinary slice with known output dim length of 1
236+
# Remove by indexing by start
237+
else:
238+
if idx.start is None:
239+
start = zero
240+
else:
241+
start = input_.owner.inputs[1 + j]
242+
j += 1
243+
new_idx_list[i] = start
244+
new_inputs += [start]
245+
246+
# Ignore useless stop and step input if there is one
247+
for slice_attr in ("stop", "step"):
248+
if getattr(idx, slice_attr) is not None:
249+
j += 1
250+
251+
# Keep non-dropped slice inputs
252+
else:
253+
for slice_attr in ("start", "stop", "step"):
254+
if getattr(idx, slice_attr) is not None:
255+
new_inputs += [input_.owner.inputs[1 + j]]
256+
j += 1
257+
# Keep non-dropped non-slice inputs
238258
else:
239259
new_inputs += [input_.owner.inputs[1 + j]]
240260
j += 1

tests/tensor/rewriting/test_uncanonicalize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,11 @@ def test_local_dimshuffle_subtensor():
214214
assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
215215
{x: np.ones((5, 1, 6, 7))}
216216
).shape == (5, 3, 7)
217+
218+
# Test dropped sliced dimensions
219+
x = matrix("x", shape=(5, 4), dtype="float64")
220+
221+
assert x[2:3, :-2].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (2,)
222+
assert x[:1, 0:3].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)
223+
assert x[-1:, :].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (4,)
224+
assert x[4:3:-1, 1:].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)

0 commit comments

Comments
 (0)