from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
+from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError
+from pytensor.raise_op import Assert
+from pytensor.scalar import ScalarConstant
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import (
    ScanArgs,
@@ -1103,6 +1111,71 @@ def sanitize(x):
    return at.as_tensor_variable(x)


+@node_rewriter([Scan])
+def while_scan_merge_subtensor_last_element(fgraph, scan_node):
+    """
+    Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for
+    recurring outputs, asserting that at least one step occurs.
+    Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`),
+    as the while scan could abort earlier anytime after that. This means it is
+    not possible to replace while_scan_out[abs(min(tap)):][-i]
+    by while_scan_out[-i], for -i != -1.
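+
+    For example (an illustrative sketch): with a single tap of -1, so that
+    abs(min(tap)) == 1, a client graph `out[1:][-1]` is rewritten to
+    `out[-1]` wrapped in `Assert("n_steps > 0")`.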
+    """
+    op = scan_node.op
+
+    if not op.info.as_while:
+        return None
+
+    # Optimization is not implemented for mit-mot
+    recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs(
+        scan_node.outputs
+    )
+    recurrent_outputs_taps_slices = (
+        op.info.mit_sot_in_slices + op.info.sit_sot_in_slices
+    )
+
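+    # The inputs alone only guarantee the first step (n_steps > 0); the while
+    # condition may stop the scan at any step after that.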
+    n_steps = scan_node.inputs[0]
+    non_zero_steps_cond = n_steps > 0
+    assert_non_zero_steps_op = Assert("n_steps > 0")
+
+    subtensor_merge_replacements = {}
+
+    # Iterate over all nodes that are two computations below the while scan
+    for node2 in get_clients_at_depth(fgraph, scan_node, depth=2):
+        if not isinstance(node2.op, Subtensor):
+            continue
+
+        node1 = node2.inputs[0].owner
+        if not (node1 and isinstance(node1.op, Subtensor)):
+            continue
+
+        x = node1.inputs[0]
+        if x not in recurrent_outputs:
+            continue
+
+        slice1 = get_idx_list(node1.inputs, node1.op.idx_list)
+        slice2 = get_idx_list(node2.inputs, node2.op.idx_list)
+
+        min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)]))
+
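+        # Match the pattern x[min_tap:][-1]: an outer slice that drops exactly
+        # the initial state, followed by a constant -1 index.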
+        if (
+            len(slice1) == 1
+            and isinstance(slice1[0], slice)
+            and isinstance(slice1[0].start, aes.ScalarConstant)
+            and slice1[0].start.data == min_tap
+            and slice1[0].stop is None
+            and slice1[0].step is None
+            and len(slice2) == 1
+            and isinstance(slice2[0], aes.ScalarConstant)
+            and slice2[0].data == -1
+        ):
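+            # Index the last computed entry directly, guarded by a runtime
+            # assert that at least one step was actually taken.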
+            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
+            copy_stack_trace([node2.outputs[0], node2.inputs[0]], out)
+            subtensor_merge_replacements[node2.outputs[0]] = out
+
+    return subtensor_merge_replacements
+
+
@node_rewriter([Scan])
def save_mem_new_scan(fgraph, node):
    r"""Graph optimizer that reduces scan memory consumption.
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node):
    that SITSOT output. Only the most recently computed timestep ever needs to
    be kept in memory.

+    There are two ways in which the Scan buffer size is controlled:
+    1. Each recurring output is saved in an input empty tensor x with the initial
+       state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):]
+       positions determine how many intermediate results should be stored.
+       This rewrite shortens x[abs(min(taps)):] to the smallest possible size.
+    2. Each non-recurrent output (nit-sot) is associated with a scalar integer
+       input that determines how many steps should be saved in the perform method.
+       This rewrite reduces this number to the smallest possible.
+
+    The scan perform implementation takes the output sizes into consideration,
+    saving the newest results over the oldest ones whenever the buffer is filled.
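+
+    For example (illustrative): if the only client of a nit-sot output is
+    out[-1], its saved-steps input can be reduced to 1, so the buffer holds a
+    single entry that each new step overwrites.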
    """
    if not isinstance(node.op, Scan):
        return False
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node):
    # index(step) for any output scan actually needs to compute
    # In other words n_steps should be equal to this maximal !
    # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
    # change the number of steps in that case. To do this we set
    # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity, while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
    assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
        global_nsteps = {"real": -1, "sym": []}
    else:
        global_nsteps = None
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node):
                else:
                    # there is a **gotcha** here ! Namely, scan returns an
                    # array that contains the initial state of the output
-                    # as well. Which means that if have a initial state of
+                    # as well. Which means that if y has an initial state of
                    # length 3, and you look for 5 steps you get an output
                    # y of length 8. If you only use y[:5], this does not
                    # mean that you only need to loop for 5 steps but
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node):

    # 2.3. Analyze global_nsteps to figure out for how many steps scan
    # needs to iterate
-    if global_nsteps is not None:
+    if global_nsteps is None:
        nw_steps = node.inputs[0]
-
+    else:
        # there are some symbolic tensors that limit the number of
        # steps
        if len(global_nsteps["sym"]) == 0:
@@ -1303,16 +1390,14 @@ def save_mem_new_scan(fgraph, node):
            real_steps = None
        nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])

+        # FIXME: This is not correct. Scan with 0 steps seems to be supported
        # Make sure the ScanSaveMem optimization never makes the new
        # number of steps to be 0 (this could happen, for instance, if
        # the optimization detects that the outputs of the Scan go through
        # subtensor nodes that end up taking no elements) because Scan with
        # 0 iterations are not supported. Make sure the new number of steps
        # is at least 1.
        nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None

    # 2.4 Loop over the clients again now looking just to see how many
    # intermediate steps to store
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node):
                    store_steps[i] = 0
                    break

-                if i > op_info.n_mit_mot:
-                    length = node.inputs[0] + init_l[i]
+                # Special case for recurrent outputs where only the last result
+                # is requested. This is needed for this rewrite to apply to
+                # do-while Scans at all. Otherwise, `get_canonical_form_slice` in
+                # the `else` branch would reintroduce a shape dependency on the
+                # original Scan which would lead this rewrite to abort in the end.
+                if (
+                    i <= op.info.n_mit_mot
+                    and isinstance(this_slice[0], ScalarConstant)
+                    and this_slice[0].value == -1
+                ):
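+                    # A constant -1 index selects the final step directly:
+                    # position nw_steps - 1, with no dependence on the output shape.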
+                    start = nw_steps - 1
                else:
-                    try:
-                        length = shape_of[out][0]
-                    except KeyError:
-                        length = out.shape[0]
-                cf_slice = get_canonical_form_slice(this_slice[0], length)
+                    if i <= op.info.n_mit_mot:
+                        try:
+                            length = shape_of[out][0]
+                        except KeyError:
+                            length = out.shape[0]
+                    else:
+                        length = node.inputs[0] + init_l[i]
+
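+                    # Canonicalize the requested slice against the buffer length
+                    # so a constant start position can be extracted below.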
+                    cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                    if isinstance(cf_slice[0], slice):
+                        start = at.extract_constant(cf_slice[0].start)
+                    else:
+                        start = at.extract_constant(cf_slice[0])

-                if isinstance(cf_slice[0], slice):
-                    start = at.extract_constant(cf_slice[0].start)
-                else:
-                    start = at.extract_constant(cf_slice[0])

                if start == 0 or store_steps[i] == 0:
                    store_steps[i] = 0
                else:
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node):
                nw_input = expand_empty(_nw_input, nw_steps)
                nw_inputs[in_idx] = nw_input
            else:
+                # FIXME: This is never used
                nw_input = nw_inputs[in_idx][: (initl + nw_steps)]

        elif (
@@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node):
                    )
                else:
                    fslice = sanitize(cnf_slice[0])
-
                nw_slice = (fslice,) + tuple(old_slices[1:])
+
                nw_pos = inv_compress_map[idx]

                subtens = Subtensor(nw_slice)
@@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node):
                    ) + tuple(old_slices[1:])

                else:
-                    position = (
-                        cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                    )
+                    # Special case when only the last value is requested
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and old_slices[0].value == -1
+                    ):
+                        position = old_slices[0]
+                    else:
+                        position = (
+                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                        )

                nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                subtens = Subtensor(nw_slice)
@@ -2403,6 +2510,12 @@ def push_out_dot1_scan(fgraph, node):
    position=5,
)

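+# Run the new while-scan rewrite with the other scan rewrites in scan_eqopt2,
+# under the same "fast_run" and "scan" tags.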
+scan_eqopt2.register(
+    "while_scan_merge_subtensor_last_element",
+    in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+)

scan_eqopt2.register(
    "constant_folding_for_scan2",