Skip to content

Commit f0e2107

Browse files
committed
Do more agressive scan memory saves in JIT backends
1 parent 50a7590 commit f0e2107

File tree

6 files changed

+75
-32
lines changed

6 files changed

+75
-32
lines changed

pytensor/compile/mode.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
454454
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
455455
)
456456

457+
NUMBA = Mode(
458+
NumbaLinker(),
459+
RewriteDatabaseQuery(
460+
include=["fast_run", "numba"],
461+
exclude=[
462+
"cxx_only",
463+
"BlasOpt",
464+
"local_careduce_fusion",
465+
"scan_save_mem_prealloc",
466+
],
467+
),
468+
)
469+
457470
JAX = Mode(
458471
JAXLinker(),
459472
RewriteDatabaseQuery(
@@ -463,6 +476,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
463476
"BlasOpt",
464477
"fusion",
465478
"inplace",
479+
"scan_save_mem_prealloc",
466480
],
467481
),
468482
)
@@ -476,16 +490,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
476490
"fusion",
477491
"inplace",
478492
"local_uint_constant_indices",
493+
"scan_save_mem_prealloc",
479494
],
480495
),
481496
)
482-
NUMBA = Mode(
483-
NumbaLinker(),
484-
RewriteDatabaseQuery(
485-
include=["fast_run", "numba"],
486-
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
487-
),
488-
)
489497

490498

491499
predefined_modes = {

pytensor/configdefaults.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1085,7 +1085,9 @@ def add_scan_configvars():
10851085
"scan__allow_output_prealloc",
10861086
"Allow/disallow memory preallocation for outputs inside of scan "
10871087
"(default: True)",
1088-
BoolParam(True),
1088+
# Non-mutable because ScanSaveMem rewrite checks it,
1089+
# and we can't have the rewrite and the implementation mismatch
1090+
BoolParam(True, mutable=False),
10891091
in_c_key=False,
10901092
)
10911093

pytensor/scan/rewriting.py

+50-15
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
get_slice_elements,
7171
set_subtensor,
7272
)
73-
from pytensor.tensor.variable import TensorConstant
73+
from pytensor.tensor.variable import TensorConstant, TensorVariable
7474

7575

7676
list_opt_slice = [
@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
11821182
return subtensor_merge_replacements
11831183

11841184

1185-
@node_rewriter([Scan])
1186-
def scan_save_mem(fgraph, node):
1185+
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
11871186
r"""Graph optimizer that reduces scan memory consumption.
11881187
11891188
This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
12141213
12151214
The scan perform implementation takes the output sizes into consideration,
12161215
saving the newest results over the oldest ones whenever the buffer is filled.
1217-
"""
1218-
if not isinstance(node.op, Scan):
1219-
return False
12201216
1217+
Paramaters
1218+
----------
1219+
backend_supports_output_pre_allocation: bool
1220+
When the backend supports output pre-allocation Scan must keep buffers
1221+
with a length of required_states + 1, because the inner function will
1222+
attempt to write the inner function outputs directly into the provided
1223+
position in the outer circular buffer. This would invalidate results,
1224+
if the input is still needed for some other output computation.
1225+
"""
12211226
if hasattr(fgraph, "shape_feature"):
12221227
shape_of = fgraph.shape_feature.shape_of
12231228
else:
@@ -1270,14 +1275,15 @@ def scan_save_mem(fgraph, node):
12701275
# Note: For simplicity while Scans also have global_nsteps set to None.
12711276
# All step optimizations require knowing the shape of the output, which
12721277
# cannot be determined from the inputs alone.
1278+
global_nsteps: None | dict
12731279
assert len(node.outputs) >= c_outs
12741280
if len(node.outputs) == c_outs and not op.info.as_while:
12751281
global_nsteps = {"real": -1, "sym": []}
12761282
else:
12771283
global_nsteps = None
12781284

12791285
# Keeps track of the original slices that each client represent
1280-
slices = [None for o in node.outputs]
1286+
slices: list[None | list] = [None for o in node.outputs]
12811287

12821288
# A list for each output indicating how many intermediate values
12831289
# should be stored. If negative it means none of the intermediate
@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
12941300
# or not
12951301
flag_store = False
12961302

1297-
# 2.2 Loop over the clients
1303+
# 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
12981304
for i, out in enumerate(node.outputs[:c_outs]):
12991305
# look at all its clients
13001306
slices[i] = []
@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
13371343
except KeyError:
13381344
length = out.shape[0]
13391345
cf_slice = get_canonical_form_slice(this_slice[0], length)
1340-
slices[i] += [(cf_slice, this_slice)]
1346+
slices[i] += [(cf_slice, this_slice)] # type: ignore
13411347

13421348
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
13431349
global_nsteps = None
@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node):
14771483
# for mitsots and sitsots (because mitmots are not
14781484
# currently supported by the mechanism) and only if
14791485
# the pre-allocation mechanism is activated.
1480-
prealloc_outs = config.scan__allow_output_prealloc
1486+
prealloc_outs = (
1487+
backend_supports_output_pre_allocation
1488+
and config.scan__allow_output_prealloc
1489+
)
14811490

14821491
first_mitsot_idx = op_info.n_mit_mot
14831492
last_sitsot_idx = (
@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node):
14861495
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
14871496

14881497
if prealloc_outs and preallocable_output:
1498+
# TODO: If there's only one output or other outputs do not depend
1499+
# on the same input, we could reduce the buffer size to the minimum
14891500
pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
14901501
else:
14911502
pval = select_max(nw_steps - start + init_l[i], init_l[i])
@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node):
16521663
name=op.name,
16531664
allow_gc=op.allow_gc,
16541665
)
1655-
new_outs = new_op(*node_ins, return_list=True)
1666+
new_outs = cast(list[TensorVariable], new_op(*node_ins, return_list=True))
16561667

16571668
old_new = []
16581669
# 3.7 Get replace pairs for those outputs that do not change
@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node):
16821693
sl_ins = get_slice_elements(
16831694
nw_slice, lambda entry: isinstance(entry, Variable)
16841695
)
1685-
new_o = subtens(new_outs[nw_pos], *sl_ins)
1696+
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
16861697
if new_o.ndim > 0:
16871698
new_o = new_o[:: cnf_slice[1]]
16881699
replaced_outs.append(idx)
@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node):
17371748
sl_ins = get_slice_elements(
17381749
nw_slice, lambda entry: isinstance(entry, Variable)
17391750
)
1740-
new_o = subtens(new_outs[nw_pos], *sl_ins)
1751+
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
17411752
if new_o.ndim > 0:
17421753
new_o = new_o[:: cnf_slice[1]]
17431754
old_new += [(old, new_o)]
@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node):
17681779
return False
17691780

17701781

1782+
@node_rewriter([Scan])
1783+
def scan_save_mem_prealloc(fgraph, node):
1784+
return scan_save_mem_rewrite(
1785+
fgraph, node, backend_supports_output_pre_allocation=True
1786+
)
1787+
1788+
1789+
@node_rewriter([Scan])
1790+
def scan_save_mem_no_prealloc(fgraph, node):
1791+
return scan_save_mem_rewrite(
1792+
fgraph, node, backend_supports_output_pre_allocation=False
1793+
)
1794+
1795+
17711796
class ScanMerge(GraphRewriter):
17721797
r"""Graph optimizer that merges different scan ops.
17731798
@@ -2495,10 +2520,20 @@ def scan_push_out_dot1(fgraph, node):
24952520
optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
24962521
# ScanSaveMem should execute only once per node.
24972522
optdb.register(
2498-
"scan_save_mem",
2499-
in2out(scan_save_mem, ignore_newtrees=True),
2523+
"scan_save_mem_prealloc",
2524+
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
25002525
"fast_run",
25012526
"scan",
2527+
"scan_save_mem",
2528+
position=1.61,
2529+
)
2530+
optdb.register(
2531+
"scan_save_mem_no_prealloc",
2532+
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
2533+
"numba",
2534+
"jax",
2535+
"pytorch",
2536+
use_db_name_as_tag=False,
25022537
position=1.61,
25032538
)
25042539
optdb.register(

pytensor/tensor/subtensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import sys
33
import warnings
4-
from collections.abc import Callable, Iterable
4+
from collections.abc import Callable, Iterable, Sequence
55
from itertools import chain, groupby
66
from textwrap import dedent
77
from typing import cast, overload
@@ -645,7 +645,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
645645

646646

647647
def get_slice_elements(
648-
idxs: list,
648+
idxs: Sequence,
649649
cond: Callable = lambda x: isinstance(x, Variable),
650650
) -> list:
651651
"""Extract slice elements conditional on a given predicate function.

tests/link/numba/test_scan.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
465465
)
466466
if buffer_size == "unit":
467467
xs_kept = xs[-1] # Only last state is used
468-
expected_buffer_size = 2
468+
expected_buffer_size = 1
469469
elif buffer_size == "aligned":
470470
xs_kept = xs[-2:] # The buffer will be aligned at the end of the 9 steps
471471
expected_buffer_size = 2
@@ -555,8 +555,7 @@ def f_pow2(x_tm2, x_tm1):
555555
accept_inplace=True,
556556
on_unused_input="ignore",
557557
)
558-
assert tuple(mitsot_buffer_shape) == (3,)
559-
558+
assert tuple(mitsot_buffer_shape) == (2,)
560559
if benchmark is not None:
561560
numba_fn.trust_input = True
562561
benchmark(numba_fn, *test_vals)

tests/scan/test_rewriting.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def rnn_step1(
742742
utt.assert_allclose(f_opt_output, f_no_opt_output)
743743

744744
def test_non_zero_init(self):
745-
"""Test the case where the initial value for the nitsot output is non-zero."""
745+
"""Test the case where the initial value for the sitsot output is non-zero."""
746746

747747
input1 = tensor3()
748748
input2 = tensor3()
@@ -759,8 +759,7 @@ def inner_fct(seq1, seq2, seq3, previous_output):
759759

760760
init = pt.as_tensor_variable(np.random.normal(size=(3, 7)))
761761

762-
# Compile the function twice, once with the optimization and once
763-
# without
762+
# Compile the function twice, once with the optimization and once without
764763
opt_mode = mode.including("scan")
765764
h, _ = pytensor.scan(
766765
inner_fct,
@@ -792,7 +791,7 @@ def inner_fct(seq1, seq2, seq3, previous_output):
792791
output_opt = f_opt(input1_value, input2_value, input3_value)
793792
output_no_opt = f_no_opt(input1_value, input2_value, input3_value)
794793

795-
utt.assert_allclose(output_opt, output_no_opt)
794+
np.testing.assert_allclose(output_opt, output_no_opt)
796795

797796

798797
class TestScanMerge:

0 commit comments

Comments
 (0)