
Commit 4b67bf5

Do more aggressive scan memory saves in JIT backends
1 parent e37a8c0 commit 4b67bf5

6 files changed: +74 -32 lines changed


pytensor/compile/mode.py

Lines changed: 15 additions & 7 deletions
@@ -454,6 +454,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     RewriteDatabaseQuery(include=["fast_run", "py_only"]),
 )
 
+NUMBA = Mode(
+    NumbaLinker(),
+    RewriteDatabaseQuery(
+        include=["fast_run", "numba"],
+        exclude=[
+            "cxx_only",
+            "BlasOpt",
+            "local_careduce_fusion",
+            "scan_save_mem_prealloc",
+        ],
+    ),
+)
+
 JAX = Mode(
     JAXLinker(),
     RewriteDatabaseQuery(
@@ -463,6 +476,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "BlasOpt",
             "fusion",
             "inplace",
+            "scan_save_mem_prealloc",
         ],
     ),
 )
@@ -476,16 +490,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "fusion",
             "inplace",
             "local_uint_constant_indices",
+            "scan_save_mem_prealloc",
         ],
     ),
 )
-NUMBA = Mode(
-    NumbaLinker(),
-    RewriteDatabaseQuery(
-        include=["fast_run", "numba"],
-        exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
-    ),
-)
 
 
 predefined_modes = {
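
The practical effect of this change is that the NUMBA, JAX, and PYTORCH modes all exclude the "scan_save_mem_prealloc" rewrite, so the JIT backends pick up the more aggressive no-prealloc variant registered in pytensor/scan/rewriting.py below. A minimal sketch of the kind of graph that benefits (the doubling recurrence, the variable names, and the use of the "NUMBA" mode string are illustrative, not taken from the commit):

import pytensor
import pytensor.tensor as pt

# A 10-step recurrence where only the final state is read afterwards.
x0 = pt.dscalar("x0")
xs, _ = pytensor.scan(
    fn=lambda x_prev: 2 * x_prev,  # hypothetical inner step
    outputs_info=[x0],
    n_steps=10,
)

# Because only xs[-1] is used, the scan memory-save rewrite can shrink the
# 10-element output buffer; in JIT modes it can now also drop the extra
# pre-allocation slot.
fn = pytensor.function([x0], xs[-1], mode="NUMBA")
print(fn(1.0))  # 1024.0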

pytensor/configdefaults.py

Lines changed: 3 additions & 1 deletion
@@ -1085,7 +1085,9 @@ def add_scan_configvars():
         "scan__allow_output_prealloc",
         "Allow/disallow memory preallocation for outputs inside of scan "
         "(default: True)",
-        BoolParam(True),
+        # Non-mutable because ScanSaveMem rewrite checks it,
+        # and we can't have the rewrite and the implementation mismatch
+        BoolParam(True, mutable=False),
         in_c_key=False,
     )
 
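
Because the parameter is now created with mutable=False, it can no longer be toggled at runtime; it has to be fixed before PyTensor is imported, for example through PYTENSOR_FLAGS. A small sketch of what that looks like (assuming the usual section__option flag syntax; assigning to the option after import is expected to fail because the parameter is immutable):

import os

# Must happen before `import pytensor`.
os.environ["PYTENSOR_FLAGS"] = "scan__allow_output_prealloc=False"

import pytensor  # noqa: E402

print(pytensor.config.scan__allow_output_prealloc)  # False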

pytensor/scan/rewriting.py

Lines changed: 49 additions & 15 deletions
@@ -70,7 +70,7 @@
     get_slice_elements,
     set_subtensor,
 )
-from pytensor.tensor.variable import TensorConstant
+from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 
 list_opt_slice = [
@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements
 
 
-@node_rewriter([Scan])
-def scan_save_mem(fgraph, node):
+def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
     r"""Graph optimizer that reduces scan memory consumption.
 
     This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
 
     The scan perform implementation takes the output sizes into consideration,
     saving the newest results over the oldest ones whenever the buffer is filled.
-    """
-    if not isinstance(node.op, Scan):
-        return False
 
+    Paramaters
+    ----------
+    backend_supports_output_pre_allocation: bool
+        When the backend supports output pre-allocation Scan must keep buffers
+        with a length of required_states + 1, because the inner function will
+        attempt to write the inner function outputs directly into the provided
+        position in the outer circular buffer. This would invalidate results,
+        if the input is still needed for some other output computation.
+    """
     if hasattr(fgraph, "shape_feature"):
         shape_of = fgraph.shape_feature.shape_of
     else:
@@ -1270,14 +1275,15 @@ def scan_save_mem(fgraph, node):
     # Note: For simplicity while Scans also have global_nsteps set to None.
     # All step optimizations require knowing the shape of the output, which
     # cannot be determined from the inputs alone.
+    global_nsteps: None | dict
     assert len(node.outputs) >= c_outs
     if len(node.outputs) == c_outs and not op.info.as_while:
         global_nsteps = {"real": -1, "sym": []}
     else:
         global_nsteps = None
 
     # Keeps track of the original slices that each client represent
-    slices = [None for o in node.outputs]
+    slices: list[None | list] = [None for o in node.outputs]
 
     # A list for each output indicating how many intermediate values
     # should be stored. If negative it means none of the intermediate
@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
     # or not
     flag_store = False
 
-    # 2.2 Loop over the clients
+    # 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
     for i, out in enumerate(node.outputs[:c_outs]):
         # look at all its clients
         slices[i] = []
@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
                 except KeyError:
                     length = out.shape[0]
                 cf_slice = get_canonical_form_slice(this_slice[0], length)
-                slices[i] += [(cf_slice, this_slice)]
+                slices[i] += [(cf_slice, this_slice)]  # type: ignore
 
                 if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
                     global_nsteps = None
@@ -1476,7 +1482,10 @@ def scan_save_mem(fgraph, node):
                 # for mitsots and sitsots (because mitmots are not
                 # currently supported by the mechanism) and only if
                 # the pre-allocation mechanism is activated.
-                prealloc_outs = config.scan__allow_output_prealloc
+                prealloc_outs = (
+                    backend_supports_output_pre_allocation
+                    and config.scan__allow_output_prealloc
+                )
 
                 first_mitsot_idx = op_info.n_mit_mot
                 last_sitsot_idx = (
@@ -1485,6 +1494,8 @@ def scan_save_mem(fgraph, node):
                 preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
 
                 if prealloc_outs and preallocable_output:
+                    # TODO: If there's only one output or other outputs do not depend
+                    # on the same input, we could reduce the buffer size to the minimum
                     pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
                 else:
                     pval = select_max(nw_steps - start + init_l[i], init_l[i])
@@ -1651,7 +1662,7 @@ def scan_save_mem(fgraph, node):
             name=op.name,
             allow_gc=op.allow_gc,
         )
-        new_outs = new_op(*node_ins, return_list=True)
+        new_outs = cast(list[TensorVariable], new_op(*node_ins, return_list=True))
 
         old_new = []
         # 3.7 Get replace pairs for those outputs that do not change
@@ -1681,7 +1692,7 @@ def scan_save_mem(fgraph, node):
                     sl_ins = get_slice_elements(
                         nw_slice, lambda entry: isinstance(entry, Variable)
                     )
-                    new_o = subtens(new_outs[nw_pos], *sl_ins)
+                    new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
                     if new_o.ndim > 0:
                         new_o = new_o[:: cnf_slice[1]]
                     replaced_outs.append(idx)
@@ -1736,7 +1747,7 @@ def scan_save_mem(fgraph, node):
                     sl_ins = get_slice_elements(
                         nw_slice, lambda entry: isinstance(entry, Variable)
                     )
-                    new_o = subtens(new_outs[nw_pos], *sl_ins)
+                    new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
                     if new_o.ndim > 0:
                         new_o = new_o[:: cnf_slice[1]]
                     old_new += [(old, new_o)]
@@ -1767,6 +1778,20 @@ def scan_save_mem(fgraph, node):
     return False
 
 
+@node_rewriter([Scan])
+def scan_save_mem_prealloc(fgraph, node):
+    return scan_save_mem_rewrite(
+        fgraph, node, backend_supports_output_pre_allocation=True
+    )
+
+
+@node_rewriter([Scan])
+def scan_save_mem_no_prealloc(fgraph, node):
+    return scan_save_mem_rewrite(
+        fgraph, node, backend_supports_output_pre_allocation=False
+    )
+
+
 class ScanMerge(GraphRewriter):
     r"""Graph optimizer that merges different scan ops.
 
@@ -2494,12 +2519,21 @@ def scan_push_out_dot1(fgraph, node):
 optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
 # ScanSaveMem should execute only once per node.
 optdb.register(
-    "scan_save_mem",
-    in2out(scan_save_mem, ignore_newtrees=True),
+    "scan_save_mem_prealloc",
+    in2out(scan_save_mem_prealloc, ignore_newtrees=True),
     "fast_run",
     "scan",
     position=1.61,
 )
+optdb.register(
+    "scan_save_mem_no_prealloc",
+    in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
+    "numba",
+    "jax",
+    "pytorch",
+    use_db_name_as_tag=False,
+    position=1.61,
+)
 optdb.register(
     "scan_make_inplace",
     ScanInplaceOptimizer(),
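
The backend_supports_output_pre_allocation flag only changes the buffer-length floor used when trimming a recurring output: with pre-allocation the buffer keeps one slot more than the number of initial states (init_l[i] + 1), without it the initial states alone set the floor. A rough stand-alone illustration of the two `pval = select_max(...)` branches above (plain Python, not PyTensor internals; `needed` stands in for `nw_steps - start + init_l[i]`, whose exact value depends on how the scan output is sliced):

def buffer_len(needed: int, init_l: int, backend_preallocates: bool) -> int:
    # Mirrors select_max(nw_steps - start + init_l[i], init_l[i] + 1) when the
    # backend pre-allocates outputs, and select_max(..., init_l[i]) otherwise.
    floor = init_l + 1 if backend_preallocates else init_l
    return max(needed, floor)

# A sitsot with one initial state whose final value alone is read afterwards:
print(buffer_len(needed=1, init_l=1, backend_preallocates=True))   # 2 (C backend)
print(buffer_len(needed=1, init_l=1, backend_preallocates=False))  # 1 (JIT backends)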

pytensor/tensor/subtensor.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import logging
 import sys
 import warnings
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Iterable, Sequence
 from itertools import chain, groupby
 from textwrap import dedent
 from typing import cast, overload
@@ -645,7 +645,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
 
 
 def get_slice_elements(
-    idxs: list,
+    idxs: Sequence,
     cond: Callable = lambda x: isinstance(x, Variable),
 ) -> list:
     """Extract slice elements conditional on a given predicate function.

tests/link/numba/test_scan.py

Lines changed: 2 additions & 3 deletions
@@ -465,7 +465,7 @@ def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
         )
         if buffer_size == "unit":
             xs_kept = xs[-1]  # Only last state is used
-            expected_buffer_size = 2
+            expected_buffer_size = 1
         elif buffer_size == "aligned":
             xs_kept = xs[-2:]  # The buffer will be aligned at the end of the 9 steps
             expected_buffer_size = 2
@@ -555,8 +555,7 @@ def f_pow2(x_tm2, x_tm1):
             accept_inplace=True,
             on_unused_input="ignore",
         )
-        assert tuple(mitsot_buffer_shape) == (3,)
-
+        assert tuple(mitsot_buffer_shape) == (2,)
         if benchmark is not None:
             numba_fn.trust_input = True
             benchmark(numba_fn, *test_vals)
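
The updated expectations reflect the smaller buffers: one slot when only the last state of a single-tap recurrence is kept, and two slots for the two-tap f_pow2 recurrence. A stand-alone sketch of the two-tap case (the recurrence, names, and the "NUMBA" mode string are illustrative; inspecting the internal buffer shape, as the test does, needs access to the compiled scan node and is omitted here):

import pytensor
import pytensor.tensor as pt

x_init = pt.dvector("x_init")  # holds the two initial taps x[t-2], x[t-1]
xs, _ = pytensor.scan(
    fn=lambda x_tm2, x_tm1: x_tm2 + x_tm1,  # hypothetical two-tap step
    outputs_info=[{"initial": x_init, "taps": [-2, -1]}],
    n_steps=10,
)

# Only the final value is read, so a two-element rotating buffer suffices
# for the mit-sot state in the Numba backend after this commit.
fn = pytensor.function([x_init], xs[-1], mode="NUMBA")
print(fn([1.0, 1.0]))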

tests/scan/test_rewriting.py

Lines changed: 3 additions & 4 deletions
@@ -742,7 +742,7 @@ def rnn_step1(
         utt.assert_allclose(f_opt_output, f_no_opt_output)
 
     def test_non_zero_init(self):
-        """Test the case where the initial value for the nitsot output is non-zero."""
+        """Test the case where the initial value for the sitsot output is non-zero."""
 
         input1 = tensor3()
         input2 = tensor3()
@@ -759,8 +759,7 @@ def inner_fct(seq1, seq2, seq3, previous_output):
 
         init = pt.as_tensor_variable(np.random.normal(size=(3, 7)))
 
-        # Compile the function twice, once with the optimization and once
-        # without
+        # Compile the function twice, once with the optimization and once without
         opt_mode = mode.including("scan")
         h, _ = pytensor.scan(
             inner_fct,
@@ -792,7 +791,7 @@ def inner_fct(seq1, seq2, seq3, previous_output):
         output_opt = f_opt(input1_value, input2_value, input3_value)
         output_no_opt = f_no_opt(input1_value, input2_value, input3_value)
 
-        utt.assert_allclose(output_opt, output_no_opt)
+        np.testing.assert_allclose(output_opt, output_no_opt)
 
 
 class TestScanMerge:
