
Commit 0dc2a1e

Simplify graph returned by Subtensor.infer_shape
1 parent a56e3a5 commit 0dc2a1e

File tree

2 files changed: +198 -52 lines changed


pytensor/tensor/subtensor.py

Lines changed: 130 additions & 31 deletions
@@ -33,14 +33,16 @@
     alloc,
     get_scalar_constant_value,
     nonzero,
+    switch,
 )
 from pytensor.tensor.basic import (
     constant as tensor_constant,
 )
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
-from pytensor.tensor.math import clip
+from pytensor.tensor.math import abs as pt_abs
+from pytensor.tensor.math import clip, eq, ge, lt, maximum, minimum, sign
 from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
@@ -55,6 +57,7 @@
     lscalar,
     tensor,
     ubscalar,
+    uint_dtypes,
     uiscalar,
     ulscalar,
     uwscalar,
@@ -254,6 +257,25 @@ def get_idx_list(inputs, idx_list):
     return indices_from_subtensor(inputs[1:], idx_list)
 
 
+def undo_scalarization(x):
+    """Undo scalarization of a variable.
+
+    PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
+    But reasoning symbolically about the result of multiple indexing operations, we usually
+    want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
+
+    This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
+    """
+    if isinstance(x, ScalarVariable):
+        if isinstance(x, ScalarConstant):
+            return tensor_constant(x.data, dtype=x.dtype)
+        elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
+            return x.owner.inputs[0]
+        else:
+            return as_tensor_variable(x)
+    return x
+
+
 @overload
 def get_canonical_form_slice(
     theslice: slice,
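As context for the helper added above: a minimal usage sketch (not part of the commit) of how undo_scalarization behaves, assuming the module-level helper this commit adds and the as_scalar converter from pytensor.scalar.basic:

    import pytensor.tensor as pt
    from pytensor.scalar.basic import as_scalar
    from pytensor.tensor.subtensor import undo_scalarization

    i = pt.lscalar("i")        # 0-d TensorVariable, as used for a symbolic index
    s = as_scalar(i)           # wrapped into a ScalarVariable via ScalarFromTensor

    assert undo_scalarization(s) is i   # unwraps ScalarFromTensor back to the TensorVariable
    assert undo_scalarization(i) is i   # anything that is not a ScalarVariable passes through

ScalarConstants follow the third branch and come back as TensorConstants, so downstream rewrites only ever see tensor variables.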
@@ -296,25 +318,6 @@ def get_canonical_form_slice(
     direction
         Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
     """
-    from pytensor.tensor import ge, lt, sign, switch
-
-    def undo_scalarization(x):
-        """Undo scalarization of a variable.
-
-        PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
-        But reasoning symbolically about the result of multiple indexing operations, we usually
-        want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
-
-        This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
-        """
-        if isinstance(x, ScalarVariable):
-            if isinstance(x, ScalarConstant):
-                return tensor_constant(x.data, dtype=x.dtype)
-            elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
-                return x.owner.inputs[0]
-            else:
-                return as_tensor_variable(x)
-        return x
 
     def analyze(x):
         try:
@@ -845,6 +848,17 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
     return ps.as_scalar(a)
 
 
+def _eager_switch(
+    cond: TensorVariable | bool, a: TensorVariable, b: TensorVariable
+) -> TensorVariable:
+    # Do not create a switch if cond is True/False
+    # We need this because uint types cannot be negative and creating the lazy switch could upcast everything to float64
+    # It also simplifies immediately the graph that's returned
+    if isinstance(cond, bool):
+        return a if cond else b
+    return cast(TensorVariable, switch(cond, a, b))
+
+
 class Subtensor(COp):
     """Basic NumPy indexing operator."""
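The comments above carry the main rationale for this helper. A rough sketch (not from the commit) of both points, assuming standard PyTensor/NumPy type promotion:

    import pytensor.tensor as pt

    xl = pt.lscalar("xl")
    zero = pt.constant(0, dtype="int64")
    b = pt.scalar("b", dtype="uint64")
    neg_one = pt.constant(-1, dtype="int64")

    # A lazy switch whose branches mix uint64 and int64 promotes the result to float64,
    # which is the upcast the comment warns about for unsigned index dtypes:
    print(pt.switch(b >= 0, b, neg_one).dtype)   # float64

    # With a Python bool condition, _eager_switch simply returns the chosen branch,
    # so the shape graph never contains a Switch node for that decision:
    assert (zero if True else xl) is zero        # what _eager_switch(True, zero, xl) reduces to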

@@ -956,27 +970,112 @@ def infer_shape(self, fgraph, node, shapes):
         padded = actual_idx_list + [slice(None, None, None)] * (
             len(xshp) - len(self.idx_list)
         )
+
+        zero = tensor_constant(np.array(0, dtype="int64"))
+        one = tensor_constant(np.array(1, dtype="int64"))
         i = 0
         for idx, xl in zip(padded, xshp, strict=True):
             if isinstance(idx, slice):
-                # If it is the default (None, None, None) slice, or a variant,
-                # the shape will be xl
+                a, b, step = idx.start, idx.stop, idx.step
                 if (
-                    (idx.start in [None, 0])
-                    and (idx.stop in [None, sys.maxsize])
-                    and (idx.step is None or idx.step == 1)
+                    a is None
+                    and b is None
+                    and step is not None
+                    and get_scalar_constant_value(step, raise_not_constant=False) == -1
                 ):
+                    # Shortcut for x[::-1]
                     outshp.append(xl)
+
                 else:
-                    cnf = get_canonical_form_slice(idx, xl)[0]
-                    if cnf.step == 1:
-                        length = cnf.stop - cnf.start
+                    if step is None:
+                        step_pos = True
+                        unit_step = True
+                        abs_step = one
+                    else:
+                        step = undo_scalarization(step)
+                        if step.dtype in uint_dtypes:
+                            step_pos = True
+                            abs_step = step.astype("int64")
+                        else:
+                            step_pos = ge(step, zero)
+                            abs_step = pt_abs(step)
+                        unit_step = eq(abs_step, one)
+
+                    if a is None:
+                        a_pos = True
+                        a = _eager_switch(step_pos, zero, xl)
                     else:
-                        length = (cnf.stop - cnf.start - 1) // cnf.step + 1
-                    outshp.append(length)
+                        a = undo_scalarization(a)
+                        if a.dtype in uint_dtypes:
+                            a_pos = True
+                            a = a.astype("int64")
+                        else:
+                            a_pos = ge(a, zero)
+
+                    if b is None:
+                        # For negative steps there is no numerical equivalent for stop=None.
+                        # The formulas below work if we set it to -1 and consider `b_pos=True`
+                        b_pos = True
+                        b = _eager_switch(step_pos, xl, -one)
+                    else:
+                        b = undo_scalarization(b)
+                        if b.dtype in uint_dtypes:
+                            b = b.astype("int64")
+                            b_pos = True
+                        else:
+                            b_pos = ge(b, zero)
+
+                    slice_length_pos_step = _eager_switch(
+                        a_pos,
+                        _eager_switch(
+                            b_pos,
+                            minimum(b - a, xl - a),  # [a: b]
+                            ((xl + b) - a),  # [a: -b]
+                        ),
+                        _eager_switch(
+                            b_pos,
+                            # The [-a: b] is peculiar, the slice length actually decreases for larger arrays
+                            # The branch -a is useless when b - a / 2 <= -a. Similar for the branch b
+                            minimum(minimum(xl, b - a - xl), minimum(-a, b)),  # [-a: b]
+                            minimum(b - a, xl + b),  # [-a: -b]
+                        ),
+                    )
+
+                    slice_length_neg_step = _eager_switch(
+                        a_pos,
+                        _eager_switch(
+                            b_pos,
+                            minimum(a - b, xl - b - one),  # [a: b]
+                            minimum(
+                                minimum(xl, a - (xl + b)), minimum(a + one, -b - one)
+                            ),  # [a: -b]
+                        ),
+                        _eager_switch(
+                            b_pos,
+                            ((xl + a) - b),  # [-a: b]
+                            minimum(a - b, xl + a + one),  # [-a: -b]
+                        ),
+                    )
+
+                    slice_length = _eager_switch(
+                        step_pos,
+                        slice_length_pos_step,
+                        slice_length_neg_step,
+                    )
+
+                    # Incorporate step size
+                    slice_length = _eager_switch(
+                        unit_step,
+                        slice_length,
+                        (slice_length - one) // abs_step + one,
+                    )
+                    # Catch negative sizes
+                    slice_length = maximum(zero, slice_length)
+                    outshp.append(slice_length)
+
                 i += 1
             else:
-                # That dimension is dropped
+                # That dimension is dropped by integer indexing
                 pass
         assert i == node.outputs[0].ndim
         assert len(outshp) == node.outputs[0].ndim
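To make the sign case analysis above concrete: a pure-Python cross-check (not part of the commit) of the positive-step formulas against Python's own slice semantics, with a, b, xl as plain integers standing in for the symbolic variables:

    def slice_len_pos_step(a, b, xl):
        # Mirrors slice_length_pos_step plus the final maximum(zero, ...) clamp.
        a_pos, b_pos = a >= 0, b >= 0
        if a_pos:
            raw = min(b - a, xl - a) if b_pos else (xl + b) - a                # [a: b] / [a: -b]
        else:
            raw = min(xl, b - a - xl, -a, b) if b_pos else min(b - a, xl + b)  # [-a: b] / [-a: -b]
        return max(0, raw)

    for xl in range(25):
        for a in range(-30, 30):
            for b in range(-30, 30):
                assert slice_len_pos_step(a, b, xl) == len(range(*slice(a, b, 1).indices(xl)))

The negative-step branch follows the same pattern with the roles of start and stop reversed; the step size is folded in afterwards via (slice_length - one) // abs_step + one.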

tests/tensor/test_subtensor.py

Lines changed: 68 additions & 21 deletions
@@ -15,10 +15,10 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
-from pytensor.graph import Constant
+from pytensor.graph import Constant, FunctionGraph
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.op import get_test_value
-from pytensor.graph.rewriting.utils import is_same_graph
+from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
 from pytensor.printing import pprint
 from pytensor.scalar.basic import as_scalar, int16
 from pytensor.tensor import as_tensor, get_vector_length, vectorize
@@ -71,6 +71,7 @@
     lscalar,
     lvector,
     matrix,
+    scalar,
     tensor,
     tensor3,
     tensor4,
@@ -1055,26 +1056,8 @@ def test_adv_sub1_idx_broadcast(self):
         assert np.allclose(g_0[0], 1)
         assert np.allclose(g_0[1:], 0)
 
-    @pytest.mark.slow
-    def test_shape_i_const(self):
-        # Each axis is treated independently by shape_i/shape operators
-
-        mode_opt = self.mode
-        data = self.shared(np.array(np.arange(5), dtype=self.dtype))
-        for start in [None, -8, -5, -1, 0, 1, 5, 8]:
-            outs = []
-            shapes = []
-            for stop in [None, -8, -5, -1, 0, 1, 5, 8]:
-                for step in [None, -3, -1, 2]:
-                    outs += [data[start:stop:step].shape]
-                    shapes += [data.get_value(borrow=True)[start:stop:step].shape]
-            f = self.function([], outs, mode=mode_opt, op=subtensor_ops, N=0)
-            t_shapes = f()
-            for t_shape, shape in zip(t_shapes, shapes, strict=True):
-                assert np.all(t_shape == shape)
-            assert Subtensor not in [x.op for x in f.maker.fgraph.toposort()]
-
     def test_shape_i_scalar(self):
+        # TODO: Move this to infer_shape
         # Each axis is treated independently by shape_i/shape operators
 
         mode_opt = self.mode
@@ -1466,6 +1449,70 @@ def test_adv1_inc_sub_notlastdim_1_2dval_no_broadcast(self):
         assert np.allclose(m2_val, m2_ref), (m2_val, m2_ref)
 
 
+class TestSubtensorInferShape:
+    _NO_OPT_MODE = Mode(linker="py", optimizer=None)
+
+    @pytest.mark.parametrize(
+        "b", [None, 0, 1, 7, 13, -1, -7, -13], ids=lambda x: f"b={x}"
+    )
+    @pytest.mark.parametrize(
+        "a", [None, 0, 1, 7, 13, -1, -7, -13], ids=lambda x: f"a={x}"
+    )
+    @pytest.mark.parametrize("step", [None, 1, 3, -1, -4], ids=lambda x: f"step={x}")
+    def test_constant_params(self, a, b, step):
+        x = vector("x", dtype="int64")
+        y = x[a:b:step].shape[0]
+
+        fg = FunctionGraph(outputs=[y], clone=False)
+        rewrite_graph(fg, include=("ShapeOpt", "canonicalize"), clone=False)
+        assert not any(isinstance(node.op, Subtensor) for node in fg.apply_nodes)
+        assert len(fg.apply_nodes) <= 9
+
+        fn = pytensor.function(
+            [x],
+            fg.outputs[0],
+            trust_input=True,
+            mode=self._NO_OPT_MODE,
+            on_unused_input="ignore",
+        )
+        x_full = np.arange(20)
+        for n in range(0, 20):
+            x_test = x_full[:n]
+            assert fn(x_test) == x_test[a:b:step].shape[0], f"failed with {n=}"
+
+    @pytest.mark.parametrize("a_dtype", (None, "int64", "uint64"))
+    @pytest.mark.parametrize("b_dtype", (None, "int64", "uint64"))
+    @pytest.mark.parametrize("step_dtype", (None, "int64", "uint64"))
+    def test_uint(self, a_dtype, b_dtype, step_dtype):
+        a = None if a_dtype is None else scalar(dtype=a_dtype)
+        b = None if b_dtype is None else scalar(dtype=b_dtype)
+        step = None if step_dtype is None else scalar(dtype=step_dtype)
+        x = vector("x", dtype="int64")
+
+        y = x[a:b:step].shape[0]
+
+        final_y = rewrite_graph(y, include=("ShapeOpt", "canonicalize"), clone=False)
+        assert final_y.dtype == "int64"
+
+        test_a = None if a is None else 1 if a_dtype.startswith("u") else -1
+        test_b = None if b is None else 10 if b_dtype.startswith("u") else -2
+        test_step = None if step is None else 2 if step_dtype.startswith("u") else -2
+        test_x = np.arange(20)
+
+        test_dict = {x: test_x}
+        if a is not None:
+            test_dict[a] = test_a
+        if b is not None:
+            test_dict[b] = test_b
+        if step is not None:
+            test_dict[step] = test_step
+
+        final_y_eval = final_y.eval(
+            test_dict, mode=self._NO_OPT_MODE, on_unused_input="ignore"
+        )
+        assert final_y_eval == test_x[test_a:test_b:test_step].shape[0]
+
+
 def test_take_basic():
     with pytest.raises(TypeError):
         take(matrix(), lvector(), axis=lscalar())
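The new TestSubtensorInferShape class checks both numerical correctness and graph size (no Subtensor node and at most 9 apply nodes after rewriting). A small interactive sketch (not part of the commit) of how the simplified shape graph can be inspected, using the same rewrite passes the test applies:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph

    x = pt.vector("x", dtype="int64")
    a = pt.scalar("a", dtype="int64")
    b = pt.scalar("b", dtype="int64")

    y = x[a:b].shape[0]
    simplified = rewrite_graph(y, include=("ShapeOpt", "canonicalize"))
    pytensor.dprint(simplified)  # a small graph of switch/minimum/maximum arithmetic, no Subtensor node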
