
Commit 63f8d6e

Optimize while scans when only last state is needed
1 parent 01e92ba commit 63f8d6e

File tree: 4 files changed (+243 -27 lines)

pytensor/scan/op.py (+1 -1)

@@ -1182,7 +1182,7 @@ def make_node(self, *inputs):
         # these are states that do not feed anything back in the recurrent
         # computation, and hence they do not have an initial state. The scan
         # node however receives an input for each such argument, the input
-        # in this case is just a int saying how many steps of this output we
+        # in this case is just an int saying how many steps of this output we
         # need to store. This input does not have the same dtype, nor is it the same
         # type of tensor as the output, it is always a scalar int.
         new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
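
The comment fixed above describes the nit-sot convention: a map-like output has
no initial state, so the only outer input it contributes is a scalar int telling
the Scan how many steps of that output to store. A minimal sketch of how to
observe that input on a compiled graph (illustrative only, not part of this
commit; the xs/ys names are hypothetical):

import pytensor
import pytensor.tensor as pt
from pytensor.scan.op import Scan

xs = pt.vector("xs")
# A plain map feeds nothing back, so `ys` is a nit-sot output
ys, _ = pytensor.scan(lambda x: 2 * x, sequences=[xs])

f = pytensor.function([xs], ys)
[scan_node] = [n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)]
# With one sequence and one nit-sot the outer inputs are
# (n_steps, sequence buffer, storage length); the last one is the scalar int
# described in the comment above.
n_steps, seq_buffer, store_len = scan_node.inputs
print(store_len.type)  # a 0-d integer tensor, not a tensor like the output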

pytensor/scan/rewriting.py (+139 -26)

@@ -28,10 +28,18 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
 from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
+from pytensor.graph.rewriting.utils import get_clients_at_depth
 from pytensor.graph.type import HasShape
 from pytensor.graph.utils import InconsistencyError
+from pytensor.raise_op import Assert
+from pytensor.scalar import ScalarConstant
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import (
     ScanArgs,
@@ -1103,6 +1111,71 @@ def sanitize(x):
     return at.as_tensor_variable(x)
 
 
+@node_rewriter([Scan])
+def while_scan_merge_subtensor_last_element(fgraph, scan_node):
+    """
+    Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for
+    recurring outputs, asserting that at least one step occurs.
+    Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`),
+    as the while scan could abort earlier anytime after that. This means it is
+    not possible to replace while_scan_out[abs(min(tap)):][-i]
+    by while_scan_out[-i], for -i != -1.
+    """
+    op = scan_node.op
+
+    if not op.info.as_while:
+        return None
+
+    # Optimization is not implemented for mit-mot
+    recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs(
+        scan_node.outputs
+    )
+    recurrent_outputs_taps_slices = (
+        op.info.mit_sot_in_slices + op.info.sit_sot_in_slices
+    )
+
+    n_steps = scan_node.inputs[0]
+    non_zero_steps_cond = n_steps > 0
+    assert_non_zero_steps_op = Assert("n_steps > 0")
+
+    subtensor_merge_replacements = {}
+
+    # Iterate over all nodes that are two computations below the while scan
+    for node2 in get_clients_at_depth(fgraph, scan_node, depth=2):
+        if not isinstance(node2.op, Subtensor):
+            continue
+
+        node1 = node2.inputs[0].owner
+        if not (node1 and isinstance(node1.op, Subtensor)):
+            continue
+
+        x = node1.inputs[0]
+        if x not in recurrent_outputs:
+            continue
+
+        slice1 = get_idx_list(node1.inputs, node1.op.idx_list)
+        slice2 = get_idx_list(node2.inputs, node2.op.idx_list)
+
+        min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)]))
+
+        if (
+            len(slice1) == 1
+            and isinstance(slice1[0], slice)
+            and isinstance(slice1[0].start, aes.ScalarConstant)
+            and slice1[0].start.data == min_tap
+            and slice1[0].stop is None
+            and slice1[0].step is None
+            and len(slice2) == 1
+            and isinstance(slice2[0], aes.ScalarConstant)
+            and slice2[0].data == -1
+        ):
+            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
+            copy_stack_trace([node2.outputs[0], node2.inputs[0]], out)
+            subtensor_merge_replacements[node2.outputs[0]] = out
+
+    return subtensor_merge_replacements
+
+
 @node_rewriter([Scan])
 def save_mem_new_scan(fgraph, node):
     r"""Graph optimizer that reduces scan memory consumption.
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node):
     that SITSOT output. Only the most recently computed timestep ever needs to
     be kept in memory.
 
+    There are two ways in which the Scan buffer size is controlled:
+    1. Each recurring output is saved in an input empty tensor x with the initial
+    state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):]
+    positions determine how many intermediate results should be stored.
+    This rewrite shortens x[abs(min(taps)):] to the smallest possible size.
+    2. Each non-recurrent output (nit-sot) is associated with a scalar integer
+    input that determines how many steps should be saved in the perform method.
+    This rewrite reduces this number to the smallest possible.
+
+    The scan perform implementation takes the output sizes into consideration,
+    saving the newest results over the oldest ones whenever the buffer is filled.
     """
     if not isinstance(node.op, Scan):
         return False
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node):
     # index(step) for any output scan actually needs to compute
     # In other words n_steps should be equal to this maximal !
     # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
     # change the number of steps in that case. To do this we set
     # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
     assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
         global_nsteps = {"real": -1, "sym": []}
     else:
         global_nsteps = None
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node):
         else:
             # there is a **gotcha** here ! Namely, scan returns an
             # array that contains the initial state of the output
-            # as well. Which means that if have a initial state of
+            # as well. Which means that if y has an initial state of
             # length 3, and you look for 5 steps you get an output
             # y of length 8. If you only use y[:5], this does not
             # mean that you only need to loop for 5 steps but
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node):
 
     # 2.3. Analyze global_nsteps to figure out for how many steps scan
     # needs to iterate
-    if global_nsteps is not None:
+    if global_nsteps is None:
         nw_steps = node.inputs[0]
-
+    else:
         # there are some symbolic tensors that limit the number of
         # steps
         if len(global_nsteps["sym"]) == 0:
@@ -1303,16 +1390,14 @@ def save_mem_new_scan(fgraph, node):
             real_steps = None
         nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])
 
+        # FIXME: This is not correct. Scan with 0 steps seems to be supported
         # Make sure the ScanSaveMem optimization never makes the new
        # number of steps to be 0 (this could happen, for instance, if
        # the optimization detects that the outputs of the Scan go through
        # subtensor nodes that end up taking no elements) because Scan with
        # 0 iterations are not supported. Make sure the new number of steps
        # is at least 1.
        nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None
 
     # 2.4 Loop over the clients again now looking just to see how many
     # intermediate steps to store
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node):
                     store_steps[i] = 0
                     break
 
-                if i > op_info.n_mit_mot:
-                    length = node.inputs[0] + init_l[i]
+                # Special case for recurrent outputs where only the last result
+                # is requested. This is needed for this rewrite to apply to
+                # do-while Scans at all. Otherwise, `get_canonical_form_slice` in
+                # the `else` branch would reintroduce a shape dependency on the
+                # original Scan which would lead this rewrite to abort in the end.
+                if (
+                    i <= op.info.n_mit_mot
+                    and isinstance(this_slice[0], ScalarConstant)
+                    and this_slice[0].value == -1
+                ):
+                    start = nw_steps - 1
                 else:
-                    try:
-                        length = shape_of[out][0]
-                    except KeyError:
-                        length = out.shape[0]
-                cf_slice = get_canonical_form_slice(this_slice[0], length)
+                    if i <= op.info.n_mit_mot:
+                        try:
+                            length = shape_of[out][0]
+                        except KeyError:
+                            length = out.shape[0]
+                    else:
+                        length = node.inputs[0] + init_l[i]
+
+                    cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                    if isinstance(cf_slice[0], slice):
+                        start = at.extract_constant(cf_slice[0].start)
+                    else:
+                        start = at.extract_constant(cf_slice[0])
 
-                if isinstance(cf_slice[0], slice):
-                    start = at.extract_constant(cf_slice[0].start)
-                else:
-                    start = at.extract_constant(cf_slice[0])
                 if start == 0 or store_steps[i] == 0:
                     store_steps[i] = 0
                 else:
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node):
                 nw_input = expand_empty(_nw_input, nw_steps)
                 nw_inputs[in_idx] = nw_input
             else:
+                # FIXME: This is never used
                 nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
 
         elif (
@@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node):
                     )
                 else:
                     fslice = sanitize(cnf_slice[0])
-
                 nw_slice = (fslice,) + tuple(old_slices[1:])
+
                 nw_pos = inv_compress_map[idx]
 
                 subtens = Subtensor(nw_slice)
@@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node):
                     ) + tuple(old_slices[1:])
 
                 else:
-                    position = (
-                        cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                    )
+                    # Special case when only last value is requested
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and old_slices[0].value == -1
+                    ):
+                        position = old_slices[0]
+                    else:
+                        position = (
+                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                        )
 
                     nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                     subtens = Subtensor(nw_slice)
@@ -2403,6 +2510,12 @@ def push_out_dot1_scan(fgraph, node):
     position=5,
 )
 
+scan_eqopt2.register(
+    "while_scan_merge_subtensor_last_element",
+    in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+)
 
 scan_eqopt2.register(
     "constant_folding_for_scan2",

pytensor/tensor/rewriting/subtensor.py (+11)

@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
     expresses all slices in a canonical form, and then merges them together.
 
     """
+    from pytensor.scan.op import Scan
 
     if isinstance(node.op, Subtensor):
         u = node.inputs[0]
@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
             # slices of the first applied subtensor
             slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
             slices2 = get_idx_list(node.inputs, node.op.idx_list)
+
+            # Don't try to do the optimization on do-while scan outputs,
+            # as it will create a dependency on the shape of the outputs
+            if (
+                x.owner is not None
+                and isinstance(x.owner.op, Scan)
+                and x.owner.op.info.as_while
+            ):
+                return None
+
             # Get the shapes of the vectors !
             try:
                 # try not to introduce new shape into the graph
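The guard above is needed because merging `x[a:][i]` into a single subtensor
canonicalizes the inner index against the length of `x[a:]`, which drags
`x.shape[0]` into the graph; for a do-while scan that shape is only known after
the scan has run. A pure-Python model of the merge arithmetic (illustrative
only, not the rewrite's actual code):

def merged_index(length: int, a: int, i: int) -> int:
    """Index into x[a:][i] expressed directly on x."""
    n = length - a                # length of the intermediate view x[a:]
    j = i if i >= 0 else n + i    # canonicalize a negative index against n
    return a + j                  # the merged position depends on `length`

x = list(range(10))
assert x[2:][-1] == x[merged_index(len(x), 2, -1)]  # both pick x[9]
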
tests/scan/test_rewriting.py (+92)

@@ -1395,6 +1395,98 @@ def f_pow2(x_tm1):
         rng = np.random.default_rng(utt.fetch_seed())
         my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))
 
+    def test_while_scan_taps(self):
+        n_steps = scalar("n_steps", dtype="int64")
+        x0 = vector("x0")
+
+        ys, _ = pytensor.scan(
+            # Fibonacci Sequence
+            lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
+            outputs_info=[{"initial": x0, "taps": [-2, -1]}],
+            n_steps=n_steps,
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+
+        f = pytensor.function(
+            [n_steps, x0], y, mode=get_default_mode().including("scan")
+        )
+
+        np.testing.assert_equal(f(n_steps=1000, x0=[1, 1]), 55)
+        np.testing.assert_equal(f(n_steps=1, x0=[1, 1]), 2)
+        with pytest.raises(AssertionError, match="n_steps > 0"):
+            f(n_steps=0, x0=[1, 1])
+
+        # ys_trace is an Alloc that controls the size of the inner buffer,
+        # it should have shape[0] == 3, with two entries for the taps and one
+        # entry for the intermediate output
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, ys_trace = scan_node.inputs
+        debug_fn = pytensor.function(
+            [n_steps, x0], ys_trace.shape[0], accept_inplace=True
+        )
+        assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
+
+    def test_while_scan_map(self):
+        xs = vector("xs")
+        ys, _ = pytensor.scan(
+            lambda x: (x + 1, {}, until(x + 1 >= 10)),
+            outputs_info=[None],
+            sequences=[xs],
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+
+        f = pytensor.function([xs], y, mode=get_default_mode().including("scan"))
+        np.testing.assert_equal(f(xs=np.arange(100, dtype=config.floatX)), 10)
+        np.testing.assert_equal(f(xs=[0]), 1)
+        with pytest.raises(IndexError):
+            f(xs=[])
+
+        # len_ys is a numerical input that controls the shape of the inner buffer
+        # It should be 1, as only the last output is needed
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, _, len_ys = scan_node.inputs
+        debug_fn = pytensor.function([xs], len_ys, accept_inplace=True)
+        assert debug_fn(xs=np.zeros((100,), dtype=config.floatX)) == 1
+
+    def test_while_scan_taps_and_map(self):
+        x0 = scalar("x0")
+        seq = vector("seq")
+        n_steps = scalar("n_steps", dtype="int64")
+
+        # while loop
+        [ys, zs], _ = pytensor.scan(
+            lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)),
+            sequences=[seq],
+            outputs_info=[x0, None],
+            n_steps=n_steps,
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+        z = zs[-1]
+
+        f = pytensor.function(
+            [x0, seq, n_steps], [y, z], mode=get_default_mode().including("scan")
+        )
+        test_seq = np.zeros(200, dtype=config.floatX)
+        np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
+        np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
+        np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
+        with pytest.raises(AssertionError, match="n_steps > 0"):
+            f(x0=0, seq=test_seq, n_steps=0)
+
+        # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
+        # If a MissingInputError is raised, it means the rewrite failed
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, _, ys_trace, len_zs = scan_node.inputs
+        debug_fn = pytensor.function(
+            [n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
+        )
+        stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
+        assert stored_ys_steps == 2
+        assert stored_zs_steps == 1
+
 
 def test_inner_replace_dot():
     """
