70
70
get_slice_elements ,
71
71
set_subtensor ,
72
72
)
73
- from pytensor .tensor .variable import TensorConstant
73
+ from pytensor .tensor .variable import TensorConstant , TensorVariable
74
74
75
75
76
76
list_opt_slice = [
@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
1182
1182
return subtensor_merge_replacements
1183
1183
1184
1184
1185
- @node_rewriter ([Scan ])
1186
- def scan_save_mem (fgraph , node ):
1185
+ def scan_save_mem_rewrite (fgraph , node , backend_supports_output_pre_allocation : bool ):
1187
1186
r"""Graph optimizer that reduces scan memory consumption.
1188
1187
1189
1188
This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
1214
1213
1215
1214
The scan perform implementation takes the output sizes into consideration,
1216
1215
saving the newest results over the oldest ones whenever the buffer is filled.
1217
- """
1218
- if not isinstance (node .op , Scan ):
1219
- return False
1220
1216
1217
+ Paramaters
1218
+ ----------
1219
+ backend_supports_output_pre_allocation: bool
1220
+ When the backend supports output pre-allocation Scan must keep buffers
1221
+ with a length of required_states + 1, because the inner function will
1222
+ attempt to write the inner function outputs directly into the provided
1223
+ position in the outer circular buffer. This would invalidate results,
1224
+ if the input is still needed for some other output computation.
1225
+ """
1221
1226
if hasattr (fgraph , "shape_feature" ):
1222
1227
shape_of = fgraph .shape_feature .shape_of
1223
1228
else :
@@ -1270,14 +1275,15 @@ def scan_save_mem(fgraph, node):
1270
1275
# Note: For simplicity while Scans also have global_nsteps set to None.
1271
1276
# All step optimizations require knowing the shape of the output, which
1272
1277
# cannot be determined from the inputs alone.
1278
+ global_nsteps : None | dict
1273
1279
assert len (node .outputs ) >= c_outs
1274
1280
if len (node .outputs ) == c_outs and not op .info .as_while :
1275
1281
global_nsteps = {"real" : - 1 , "sym" : []}
1276
1282
else :
1277
1283
global_nsteps = None
1278
1284
1279
1285
# Keeps track of the original slices that each client represent
1280
- slices = [None for o in node .outputs ]
1286
+ slices : list [ None | list ] = [None for o in node .outputs ]
1281
1287
1282
1288
# A list for each output indicating how many intermediate values
1283
1289
# should be stored. If negative it means none of the intermediate
@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
1294
1300
# or not
1295
1301
flag_store = False
1296
1302
1297
- # 2.2 Loop over the clients
1303
+ # 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
1298
1304
for i , out in enumerate (node .outputs [:c_outs ]):
1299
1305
# look at all its clients
1300
1306
slices [i ] = []
@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
1337
1343
except KeyError :
1338
1344
length = out .shape [0 ]
1339
1345
cf_slice = get_canonical_form_slice (this_slice [0 ], length )
1340
- slices [i ] += [(cf_slice , this_slice )]
1346
+ slices [i ] += [(cf_slice , this_slice )] # type: ignore
1341
1347
1342
1348
if isinstance (this_slice [0 ], slice ) and this_slice [0 ].stop is None :
1343
1349
global_nsteps = None
@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node):
1477
1483
# for mitsots and sitsots (because mitmots are not
1478
1484
# currently supported by the mechanism) and only if
1479
1485
# the pre-allocation mechanism is activated.
1480
- prealloc_outs = config .scan__allow_output_prealloc
1486
+ prealloc_outs = (
1487
+ backend_supports_output_pre_allocation
1488
+ and config .scan__allow_output_prealloc
1489
+ )
1481
1490
1482
1491
first_mitsot_idx = op_info .n_mit_mot
1483
1492
last_sitsot_idx = (
@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node):
1486
1495
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
1487
1496
1488
1497
if prealloc_outs and preallocable_output :
1498
+ # TODO: If there's only one output or other outputs do not depend
1499
+ # on the same input, we could reduce the buffer size to the minimum
1489
1500
pval = select_max (nw_steps - start + init_l [i ], init_l [i ] + 1 )
1490
1501
else :
1491
1502
pval = select_max (nw_steps - start + init_l [i ], init_l [i ])
@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node):
1652
1663
name = op .name ,
1653
1664
allow_gc = op .allow_gc ,
1654
1665
)
1655
- new_outs = new_op (* node_ins , return_list = True )
1666
+ new_outs = cast ( list [ TensorVariable ], new_op (* node_ins , return_list = True ) )
1656
1667
1657
1668
old_new = []
1658
1669
# 3.7 Get replace pairs for those outputs that do not change
@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node):
1682
1693
sl_ins = get_slice_elements (
1683
1694
nw_slice , lambda entry : isinstance (entry , Variable )
1684
1695
)
1685
- new_o = subtens (new_outs [nw_pos ], * sl_ins )
1696
+ new_o = cast ( TensorVariable , subtens (new_outs [nw_pos ], * sl_ins ) )
1686
1697
if new_o .ndim > 0 :
1687
1698
new_o = new_o [:: cnf_slice [1 ]]
1688
1699
replaced_outs .append (idx )
@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node):
1737
1748
sl_ins = get_slice_elements (
1738
1749
nw_slice , lambda entry : isinstance (entry , Variable )
1739
1750
)
1740
- new_o = subtens (new_outs [nw_pos ], * sl_ins )
1751
+ new_o = cast ( TensorVariable , subtens (new_outs [nw_pos ], * sl_ins ) )
1741
1752
if new_o .ndim > 0 :
1742
1753
new_o = new_o [:: cnf_slice [1 ]]
1743
1754
old_new += [(old , new_o )]
@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node):
1768
1779
return False
1769
1780
1770
1781
1782
+ @node_rewriter ([Scan ])
1783
+ def scan_save_mem_prealloc (fgraph , node ):
1784
+ return scan_save_mem_rewrite (
1785
+ fgraph , node , backend_supports_output_pre_allocation = True
1786
+ )
1787
+
1788
+
1789
+ @node_rewriter ([Scan ])
1790
+ def scan_save_mem_no_prealloc (fgraph , node ):
1791
+ return scan_save_mem_rewrite (
1792
+ fgraph , node , backend_supports_output_pre_allocation = False
1793
+ )
1794
+
1795
+
1771
1796
class ScanMerge (GraphRewriter ):
1772
1797
r"""Graph optimizer that merges different scan ops.
1773
1798
@@ -2495,10 +2520,20 @@ def scan_push_out_dot1(fgraph, node):
2495
2520
optdb .register ("scan_eqopt2" , scan_eqopt2 , "fast_run" , "scan" , position = 1.6 )
2496
2521
# ScanSaveMem should execute only once per node.
2497
2522
optdb .register (
2498
- "scan_save_mem " ,
2499
- in2out (scan_save_mem , ignore_newtrees = True ),
2523
+ "scan_save_mem_prealloc " ,
2524
+ in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
2500
2525
"fast_run" ,
2501
2526
"scan" ,
2527
+ "scan_save_mem" ,
2528
+ position = 1.61 ,
2529
+ )
2530
+ optdb .register (
2531
+ "scan_save_mem_no_prealloc" ,
2532
+ in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2533
+ "numba" ,
2534
+ "jax" ,
2535
+ "pytorch" ,
2536
+ use_db_name_as_tag = False ,
2502
2537
position = 1.61 ,
2503
2538
)
2504
2539
optdb .register (
0 commit comments