@@ -174,11 +174,15 @@ def local_dimshuffle_alloc(fgraph, node):
174
174
def local_dimshuffle_subtensor (fgraph , node ):
175
175
"""If a subtensor is inside a dimshuffle which only drop
176
176
broadcastable dimensions, scrap the dimshuffle and index the
177
- subtensor with 0
177
+ subtensor in a way that avoids the degenerate dimension
178
178
179
179
x[i:j, :, k:l].dimshuffle(0, 2) =>
180
180
x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
181
181
182
+ x[i:j, k:l, :].dimshuffle(0, 2) => x[i:j, k, :]
183
+ x[i:j, k:, :].dimshuffle(0, 2) => x[i:j, k, :]
184
+ x[i:j, :l, :].dimshuffle(0, 2) => x[i:j, 0, :]
185
+
182
186
"""
183
187
if isinstance (node .op , DimShuffle ) and node .inputs [0 ].owner :
184
188
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
@@ -217,24 +221,40 @@ def local_dimshuffle_subtensor(fgraph, node):
217
221
new_idx_list = list (input_ .owner .op .idx_list )
218
222
new_inputs = [input_ .owner .inputs [0 ]]
219
223
zero = constant (0 )
220
- slice_attr_list = ["start" , "stop" , "step" ]
221
224
j = 0
222
225
slice_i = - 1
223
226
subtensor_removed_dims = 0
224
227
for i , idx in enumerate (input_ .owner .op .idx_list ):
225
228
if isinstance (idx , slice ):
226
- past_j = j
227
229
slice_i += 1
228
- for slice_attr in slice_attr_list :
229
- if getattr (idx , slice_attr ) is not None :
230
- new_inputs += [input_ .owner .inputs [1 + j ]]
231
- j += 1
232
- # if past_j == j indicates a slice(None, None, None),
233
- # that's where we want to index with 0 if it is also at
234
- # the same spot of a missing dim
235
- if past_j == j and slice_i in missing_dims :
236
- new_idx_list [i ] = zero
237
- new_inputs += [zero ]
230
+ if slice_i in missing_dims :
231
+ # Missing dim is a slice(None), remove by indexing by 0
232
+ if idx == slice (None ):
233
+ new_idx_list [i ] = zero
234
+ new_inputs += [zero ]
235
+ # Missing dim is an ordinary slice with known output dim length of 1
236
+ # Remove by indexing by start
237
+ else :
238
+ if idx .start is None :
239
+ start = zero
240
+ else :
241
+ start = input_ .owner .inputs [1 + j ]
242
+ j += 1
243
+ new_idx_list [i ] = start
244
+ new_inputs += [start ]
245
+
246
+ # Ignore useless stop and step input if there is one
247
+ for slice_attr in ("stop" , "step" ):
248
+ if getattr (idx , slice_attr ) is not None :
249
+ j += 1
250
+
251
+ # Keep non-dropped slice inputs
252
+ else :
253
+ for slice_attr in ("start" , "stop" , "step" ):
254
+ if getattr (idx , slice_attr ) is not None :
255
+ new_inputs += [input_ .owner .inputs [1 + j ]]
256
+ j += 1
257
+ # Keep non-dropped non-slice inputs
238
258
else :
239
259
new_inputs += [input_ .owner .inputs [1 + j ]]
240
260
j += 1
0 commit comments