|
35 | 35 | nonzero,
|
36 | 36 | scalar_from_tensor,
|
37 | 37 | )
|
| 38 | +from pytensor.tensor.basic import ( |
| 39 | + constant as tensor_constant, |
| 40 | +) |
38 | 41 | from pytensor.tensor.blockwise import vectorize_node_fallback
|
39 | 42 | from pytensor.tensor.elemwise import DimShuffle
|
40 | 43 | from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
|
@@ -252,6 +255,23 @@ def get_idx_list(inputs, idx_list):
|
252 | 255 | return indices_from_subtensor(inputs[1:], idx_list)
|
253 | 256 |
|
254 | 257 |
|
| 258 | +def undo_scalarization(x): |
| 259 | + """Undo scalarization of a variable. |
| 260 | +
|
| 261 | + PyTensor Basic index operations use ScalarVariables for the indices/slice arguments. |
| 262 | + When reason symbolically about the result of multiple indexing operations, we usually |
| 263 | + want to work on TensorVariables, since rewrites work on those and not ScalarVariables. |
| 264 | +
|
| 265 | + This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants. |
| 266 | + """ |
| 267 | + if isinstance(x, ScalarVariable): |
| 268 | + if isinstance(x, ScalarConstant): |
| 269 | + return tensor_constant(x.data, dtype=x.dtype) |
| 270 | + elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor): |
| 271 | + return x.owner.inputs[0] |
| 272 | + return x |
| 273 | + |
| 274 | + |
255 | 275 | @overload
|
256 | 276 | def get_canonical_form_slice(
|
257 | 277 | theslice: slice,
|
@@ -298,6 +318,7 @@ def get_canonical_form_slice(
|
298 | 318 |
|
299 | 319 | # Other non-slice types are the scalar indexing case
|
300 | 320 | if not isinstance(theslice, slice):
|
| 321 | + theslice = undo_scalarization(theslice) |
301 | 322 | if isinstance(theslice, int | np.integer | ScalarVariable) or (
|
302 | 323 | isinstance(theslice, TensorVariable) and theslice.ndim == 0
|
303 | 324 | ):
|
@@ -381,6 +402,7 @@ def analyze(x):
|
381 | 402 | elif is_stop_length:
|
382 | 403 | # start:length:1
|
383 | 404 | if is_start_constant and start >= 0:
|
| 405 | + length = undo_scalarization(length) |
384 | 406 | return slice(switch(lt(start, length), start, length), length, 1), 1
|
385 | 407 | start_plus_len = start + length
|
386 | 408 | start = switch(
|
|
0 commit comments