33 | 33 |     alloc,
34 | 34 |     get_scalar_constant_value,
35 | 35 |     nonzero,
| 36 | +    switch,
36 | 37 | )
37 | 38 | from pytensor.tensor.basic import (
38 | 39 |     constant as tensor_constant,
39 | 40 | )
40 | 41 | from pytensor.tensor.blockwise import vectorize_node_fallback
41 | 42 | from pytensor.tensor.elemwise import DimShuffle
42 | 43 | from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
43 | | -from pytensor.tensor.math import clip
| 44 | +from pytensor.tensor.math import abs as pt_abs
| 45 | +from pytensor.tensor.math import clip, eq, ge, lt, maximum, minimum, sign
44 | 46 | from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
45 | 47 | from pytensor.tensor.type import (
46 | 48 |     TensorType,

55 | 57 |     lscalar,
56 | 58 |     tensor,
57 | 59 |     ubscalar,
| 60 | +    uint_dtypes,
58 | 61 |     uiscalar,
59 | 62 |     ulscalar,
60 | 63 |     uwscalar,
@@ -254,6 +257,25 @@ def get_idx_list(inputs, idx_list):
254 | 257 |     return indices_from_subtensor(inputs[1:], idx_list)
255 | 258 |
256 | 259 |
| 260 | +def undo_scalarization(x):
| 261 | +    """Undo scalarization of a variable.
| 262 | +
| 263 | +    PyTensor basic index operations use ScalarVariables for the indices/slice arguments.
| 264 | +    But when reasoning symbolically about the result of multiple indexing operations, we usually
| 265 | +    want to work on TensorVariables, since rewrites apply to those and not to ScalarVariables.
| 266 | +
| 267 | +    This function undoes the ScalarFromTensor operation, or converts ScalarConstants to TensorConstants.
| 268 | +    """
| 269 | +    if isinstance(x, ScalarVariable):
| 270 | +        if isinstance(x, ScalarConstant):
| 271 | +            return tensor_constant(x.data, dtype=x.dtype)
| 272 | +        elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
| 273 | +            return x.owner.inputs[0]
| 274 | +        else:
| 275 | +            return as_tensor_variable(x)
| 276 | +    return x
| 277 | +
| 278 | +
257 | 279 | @overload
258 | 280 | def get_canonical_form_slice(
259 | 281 |     theslice: slice,
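Note: `undo_scalarization` is promoted from a closure inside `get_canonical_form_slice` (removed in the next hunk) to module level, so that `Subtensor.infer_shape` below can reuse it. A minimal usage sketch, assuming the helper is imported from `pytensor.tensor.subtensor` and that `scalar_from_tensor` and `ps.constant` behave as in current PyTensor:

```python
import pytensor.scalar as ps
from pytensor.tensor import lscalar
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.subtensor import undo_scalarization

x = lscalar("x")           # 0-d int64 TensorVariable
s = scalar_from_tensor(x)  # ScalarVariable wrapping x via a ScalarFromTensor node

# The helper peels off the ScalarFromTensor node, recovering x itself
assert undo_scalarization(s) is x

# ScalarConstants are converted to equivalent TensorConstants
c = undo_scalarization(ps.constant(5))
print(type(c).__name__)  # TensorConstant
```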
@@ -296,25 +318,6 @@ def get_canonical_form_slice(
296 | 318 |     direction
297 | 319 |         Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
298 | 320 |     """
299 | | -    from pytensor.tensor import ge, lt, sign, switch
300 | | -
301 | | -    def undo_scalarization(x):
302 | | -        """Undo scalarization of a variable.
303 | | -
304 | | -        PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
305 | | -        But reasoning symbolically about the result of multiple indexing operations, we usually
306 | | -        want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
307 | | -
308 | | -        This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
309 | | -        """
310 | | -        if isinstance(x, ScalarVariable):
311 | | -            if isinstance(x, ScalarConstant):
312 | | -                return tensor_constant(x.data, dtype=x.dtype)
313 | | -            elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
314 | | -                return x.owner.inputs[0]
315 | | -            else:
316 | | -                return as_tensor_variable(x)
317 | | -        return x
318 | 321 |
319 | 322 |     def analyze(x):
320 | 323 |         try:
@@ -845,6 +848,17 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
845 | 848 |     return ps.as_scalar(a)
846 | 849 |
847 | 850 |
| 851 | +def _eager_switch(
| 852 | +    cond: TensorVariable | bool, a: TensorVariable, b: TensorVariable
| 853 | +) -> TensorVariable:
| 854 | +    # Do not create a switch when cond is a Python True/False.
| 855 | +    # We need this because uint types cannot be negative, and creating the lazy switch could upcast everything to float64.
| 856 | +    # It also immediately simplifies the graph that is returned.
| 857 | +    if isinstance(cond, bool):
| 858 | +        return a if cond else b
| 859 | +    return cast(TensorVariable, switch(cond, a, b))
| 860 | +
| 861 | +
848 | 862 | class Subtensor(COp):
849 | 863 |     """Basic NumPy indexing operator."""
850 | 864 |
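A short behavioural sketch of `_eager_switch` (assuming it is accessible from `pytensor.tensor.subtensor` after this change; the symbolic case uses standard PyTensor `switch` semantics):

```python
from pytensor.tensor import lscalar
from pytensor.tensor.subtensor import _eager_switch

a, b = lscalar("a"), lscalar("b")

# Python bool condition: the branch is chosen eagerly and no graph node is built,
# so uint branches cannot trigger an upcast through a lazy switch
assert _eager_switch(True, a, b) is a
assert _eager_switch(False, a, b) is b

# Symbolic condition: falls back to a lazy elementwise switch node
cond = lscalar("i") > 0
out = _eager_switch(cond, a, b)
assert out.owner is not None  # a switch Apply node was created
```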
@@ -956,27 +970,112 @@ def infer_shape(self, fgraph, node, shapes):
956 | 970 |         padded = actual_idx_list + [slice(None, None, None)] * (
957 | 971 |             len(xshp) - len(self.idx_list)
958 | 972 |         )
| 973 | +
| 974 | +        zero = tensor_constant(np.array(0, dtype="int64"))
| 975 | +        one = tensor_constant(np.array(1, dtype="int64"))
959 | 976 |         i = 0
960 | 977 |         for idx, xl in zip(padded, xshp, strict=True):
961 | 978 |             if isinstance(idx, slice):
962 | | -                # If it is the default (None, None, None) slice, or a variant,
963 | | -                # the shape will be xl
| 979 | +                a, b, step = idx.start, idx.stop, idx.step
964 | 980 |                 if (
965 | | -                    (idx.start in [None, 0])
966 | | -                    and (idx.stop in [None, sys.maxsize])
967 | | -                    and (idx.step is None or idx.step == 1)
| 981 | +                    a is None
| 982 | +                    and b is None
| 983 | +                    and step is not None
| 984 | +                    and get_scalar_constant_value(step, raise_not_constant=False) == -1
968 | 985 |                 ):
| 986 | +                    # Shortcut for x[::-1]
969 | 987 |                     outshp.append(xl)
| 988 | +
970 | 989 |                 else:
971 | | -                    cnf = get_canonical_form_slice(idx, xl)[0]
972 | | -                    if cnf.step == 1:
973 | | -                        length = cnf.stop - cnf.start
| 990 | +                    if step is None:
| 991 | +                        step_pos = True
| 992 | +                        unit_step = True
| 993 | +                        abs_step = one
| 994 | +                    else:
| 995 | +                        step = undo_scalarization(step)
| 996 | +                        if step.dtype in uint_dtypes:
| 997 | +                            step_pos = True
| 998 | +                            abs_step = step.astype("int64")
| 999 | +                        else:
| 1000 | +                            step_pos = ge(step, zero)
| 1001 | +                            abs_step = pt_abs(step)
| 1002 | +                        unit_step = eq(abs_step, one)
| 1003 | +
| 1004 | +                    if a is None:
| 1005 | +                        a_pos = True
| 1006 | +                        a = _eager_switch(step_pos, zero, xl)
974 | 1007 |                     else:
975 | | -                        length = (cnf.stop - cnf.start - 1) // cnf.step + 1
976 | | -                    outshp.append(length)
| 1008 | +                        a = undo_scalarization(a)
| 1009 | +                        if a.dtype in uint_dtypes:
| 1010 | +                            a_pos = True
| 1011 | +                            a = a.astype("int64")
| 1012 | +                        else:
| 1013 | +                            a_pos = ge(a, zero)
| 1014 | +
| 1015 | +                    if b is None:
| 1016 | +                        # For negative steps there is no numerical equivalent for stop=None.
| 1017 | +                        # The formulas below work if we set it to -1 and consider `b_pos=True`.
| 1018 | +                        b_pos = True
| 1019 | +                        b = _eager_switch(step_pos, xl, -one)
| 1020 | +                    else:
| 1021 | +                        b = undo_scalarization(b)
| 1022 | +                        if b.dtype in uint_dtypes:
| 1023 | +                            b = b.astype("int64")
| 1024 | +                            b_pos = True
| 1025 | +                        else:
| 1026 | +                            b_pos = ge(b, zero)
| 1027 | +
| 1028 | +                    slice_length_pos_step = _eager_switch(
| 1029 | +                        a_pos,
| 1030 | +                        _eager_switch(
| 1031 | +                            b_pos,
| 1032 | +                            minimum(b - a, xl - a),  # [a: b]
| 1033 | +                            ((xl + b) - a),  # [a: -b]
| 1034 | +                        ),
| 1035 | +                        _eager_switch(
| 1036 | +                            b_pos,
| 1037 | +                            # The [-a: b] case is peculiar: the slice length actually decreases for larger arrays.
| 1038 | +                            # The -a branch is useless when b - a / 2 <= -a. Similarly for the b branch.
| 1039 | +                            minimum(minimum(xl, b - a - xl), minimum(-a, b)),  # [-a: b]
| 1040 | +                            minimum(b - a, xl + b),  # [-a: -b]
| 1041 | +                        ),
| 1042 | +                    )
| 1043 | +
| 1044 | +                    slice_length_neg_step = _eager_switch(
| 1045 | +                        a_pos,
| 1046 | +                        _eager_switch(
| 1047 | +                            b_pos,
| 1048 | +                            minimum(a - b, xl - b - one),  # [a: b]
| 1049 | +                            minimum(
| 1050 | +                                minimum(xl, a - (xl + b)), minimum(a + one, -b - one)
| 1051 | +                            ),  # [a: -b]
| 1052 | +                        ),
| 1053 | +                        _eager_switch(
| 1054 | +                            b_pos,
| 1055 | +                            ((xl + a) - b),  # [-a: b]
| 1056 | +                            minimum(a - b, xl + a + one),  # [-a: -b]
| 1057 | +                        ),
| 1058 | +                    )
| 1059 | +
| 1060 | +                    slice_length = _eager_switch(
| 1061 | +                        step_pos,
| 1062 | +                        slice_length_pos_step,
| 1063 | +                        slice_length_neg_step,
| 1064 | +                    )
| 1065 | +
| 1066 | +                    # Incorporate the step size
| 1067 | +                    slice_length = _eager_switch(
| 1068 | +                        unit_step,
| 1069 | +                        slice_length,
| 1070 | +                        (slice_length - one) // abs_step + one,
| 1071 | +                    )
| 1072 | +                    # Catch negative sizes
| 1073 | +                    slice_length = maximum(zero, slice_length)
| 1074 | +                    outshp.append(slice_length)
| 1075 | +
977 | 1076 |                 i += 1
978 | 1077 |             else:
979 | | -                # That dimension is dropped
| 1078 | +                # That dimension is dropped by integer indexing
980 | 1079 |                 pass
981 | 1080 |         assert i == node.outputs[0].ndim
982 | 1081 |         assert len(outshp) == node.outputs[0].ndim
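The closed-form slice-length formulas above can be sanity-checked by mirroring the branches in plain Python and comparing against Python's own slice semantics (`slice.indices`). The sketch below is illustrative only; it is not part of the diff or its test suite:

```python
def eager_slice_length(a, b, step, xl):
    """Plain-Python mirror of the infer_shape slice branches above."""
    if step is None:
        step = 1
    assert step != 0  # step=0 is not a valid slice
    step_pos = step >= 0
    abs_step = abs(step)
    if a is None:
        a, a_pos = (0 if step_pos else xl), True
    else:
        a_pos = a >= 0
    if b is None:
        # No numerical equivalent of stop=None for negative steps: use -1 with b_pos=True
        b, b_pos = (xl if step_pos else -1), True
    else:
        b_pos = b >= 0
    if step_pos:
        if a_pos:
            length = min(b - a, xl - a) if b_pos else (xl + b) - a  # [a: b] / [a: -b]
        else:
            length = (
                min(min(xl, b - a - xl), min(-a, b))  # [-a: b]
                if b_pos
                else min(b - a, xl + b)  # [-a: -b]
            )
    else:
        if a_pos:
            length = (
                min(a - b, xl - b - 1)  # [a: b]
                if b_pos
                else min(min(xl, a - (xl + b)), min(a + 1, -b - 1))  # [a: -b]
            )
        else:
            length = (xl + a) - b if b_pos else min(a - b, xl + a + 1)  # [-a: b] / [-a: -b]
    # Incorporate the step size, then clamp negative lengths to zero
    return max(0, (length - 1) // abs_step + 1)


# Exhaustive comparison against Python's own slicing rules
for xl in range(7):
    for start in (None, *range(-xl - 2, xl + 3)):
        for stop in (None, *range(-xl - 2, xl + 3)):
            for step in (None, -3, -2, -1, 1, 2, 3):
                expected = len(range(*slice(start, stop, step).indices(xl)))
                assert eager_slice_length(start, stop, step, xl) == expected
```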