
Speed up Scan in different backends #1281





Merged (9 commits) on Mar 13, 2025
22 changes: 15 additions & 7 deletions pytensor/compile/mode.py
@@ -454,6 +454,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)

NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=[
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
],
),
)

JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(
@@ -463,6 +476,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
)
@@ -476,16 +490,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"fusion",
"inplace",
"local_uint_constant_indices",
"scan_save_mem_prealloc",
],
),
)
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)


predefined_modes = {
4 changes: 3 additions & 1 deletion pytensor/configdefaults.py
@@ -1085,7 +1085,9 @@ def add_scan_configvars():
"scan__allow_output_prealloc",
"Allow/disallow memory preallocation for outputs inside of scan "
"(default: True)",
BoolParam(True),
# Non-mutable because the ScanSaveMem rewrite checks it,
# and we can't have the rewrite and the implementation disagree
BoolParam(True, mutable=False),
in_c_key=False,
)

2 changes: 1 addition & 1 deletion pytensor/link/jax/dispatch/scan.py
@@ -29,7 +29,7 @@ def scan(*outer_inputs):
# Extract JAX scan inputs
outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length`
seqs = op.outer_seqs(outer_inputs) # JAX `xs`
seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`
@ricardoV94 (Member, Author) commented on Mar 10, 2025:
PyTensor Scan allows sequences to be longer than the number of steps, in which case the extra entries simply go unused. The Scan save-memory rewrite doesn't bother trimming them.

JAX, however, doesn't allow it. Fixing the constant n_steps optimization revealed this issue.
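
For illustration only (not part of the diff), a minimal standalone sketch of the mismatch described above; it assumes jax.lax.scan raises when `length` disagrees with the leading axis of the sequences in `xs`:

import jax
import jax.numpy as jnp

def step(carry, x_t):
    new_carry = carry + x_t
    return new_carry, new_carry

xs = jnp.arange(10.0)  # sequence with more entries than the number of steps
n_steps = 5

# jax.lax.scan(step, 0.0, xs, length=n_steps)  # fails: xs is longer than `length`
carry, ys = jax.lax.scan(step, 0.0, xs[:n_steps], length=n_steps)  # trim first, as the dispatch now does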


mit_sot_init = []
for tap, seq in zip(
39 changes: 32 additions & 7 deletions pytensor/link/numba/dispatch/scan.py
@@ -55,7 +55,7 @@ def range_arr(x):


@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
def numba_funcify_Scan(op: Scan, node, **kwargs):
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan?
@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
.optimizer
)
fgraph = op.fgraph
# When the buffer can only hold one SITSOT entry, or only as many MITSOT entries as there are taps,
# we must always discard the oldest tap, so it's safe to destroy it in the inner function.
# TODO: Allow inplace for MITMOT
destroyable_sitsot = [
inner_sitsot
for outer_sitsot, inner_sitsot in zip(
op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True
)
if outer_sitsot.type.shape[0] == 1
]
destroyable_mitsot = [
oldest_inner_mitmot
for outer_mitsot, oldest_inner_mitmot, taps in zip(
op.outer_mitsot(node.inputs),
op.oldest_inner_mitsot(fgraph.inputs),
op.info.mit_sot_in_slices,
strict=True,
)
if outer_mitsot.type.shape[0] == abs(min(taps))
]
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
input_specs=[
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
],
accept_inplace=True,
)
rewriter(fgraph)
@@ -222,14 +245,16 @@ def add_output_storage_post_proc_stmt(
# the storage array.
# This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`.
# If the storage size only allows one entry, there's nothing to rotate
output_storage_post_proc_stmts.append(
dedent(
f"""
if (i + {tap_size}) > {storage_size}:
if 1 < {storage_size} < (i + {tap_size}):
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
"""
).strip()
)
@@ -417,4 +442,4 @@ def scan({", ".join(outer_in_names)}):

scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})

return numba_basic.numba_njit(scan_op_fn)
return numba_basic.numba_njit(scan_op_fn, boundscheck=False)
A reviewer (Member) commented: boundscheck defaults to False in Numba, do we override that?

@ricardoV94 (Member, Author) replied: Probably not; found this easier to read.
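
For reference, a small sketch (not from the PR) of the two spellings being discussed; it assumes Numba's default is boundscheck=False unless overridden globally (e.g. via the NUMBA_BOUNDSCHECK environment variable or full debug mode):

import numba
import numpy as np

@numba.njit  # relies on Numba's default: no bounds checking
def take_third_implicit(x):
    return x[3]

@numba.njit(boundscheck=False)  # what the generated scan function now states explicitly
def take_third_explicit(x):
    return x[3]

x = np.arange(10)
assert take_third_implicit(x) == take_third_explicit(x) == 3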

10 changes: 10 additions & 0 deletions pytensor/scan/op.py
@@ -321,6 +321,16 @@ def inner_mitsot(self, list_inputs):
self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot
]

def oldest_inner_mitsot(self, list_inputs):
inner_mitsot_inputs = self.inner_mitsot(list_inputs)
oldest_inner_mitsot_inputs = []
offset = 0
for taps in self.info.mit_sot_in_slices:
oldest_tap = np.argmin(taps)
oldest_inner_mitsot_inputs += [inner_mitsot_inputs[offset + oldest_tap]]
offset += len(taps)
return oldest_inner_mitsot_inputs

def outer_mitsot(self, list_inputs):
offset = 1 + self.info.n_seqs + self.info.n_mit_mot
return list_inputs[offset : offset + self.info.n_mit_sot]
105 changes: 68 additions & 37 deletions pytensor/scan/rewriting.py
Expand Up @@ -3,7 +3,6 @@
import copy
import dataclasses
from itertools import chain
from sys import maxsize
from typing import cast

import numpy as np
@@ -71,7 +70,7 @@
get_slice_elements,
set_subtensor,
)
from pytensor.tensor.variable import TensorConstant
from pytensor.tensor.variable import TensorConstant, TensorVariable


list_opt_slice = [
@@ -1183,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements


@node_rewriter([Scan])
def scan_save_mem(fgraph, node):
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
r"""Graph optimizer that reduces scan memory consumption.

This optimization attempts to determine if a `Scan` node, during its execution,
@@ -1215,10 +1213,16 @@ def scan_save_mem(fgraph, node):

The scan perform implementation takes the output sizes into consideration,
saving the newest results over the oldest ones whenever the buffer is filled.
"""
if not isinstance(node.op, Scan):
return False

Parameters
----------
backend_supports_output_pre_allocation: bool
When the backend supports output pre-allocation, Scan must keep buffers
with a length of required_states + 1, because the inner function will
attempt to write its outputs directly into the provided position in the
outer circular buffer. This would invalidate results if the input is
still needed for some other output computation.
"""
if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of
else:
@@ -1271,14 +1275,15 @@ def scan_save_mem(fgraph, node):
# Note: For simplicity while Scans also have global_nsteps set to None.
# All step optimizations require knowing the shape of the output, which
# cannot be determined from the inputs alone.
global_nsteps: None | dict
assert len(node.outputs) >= c_outs
if len(node.outputs) == c_outs and not op.info.as_while:
global_nsteps = {"real": -1, "sym": []}
else:
global_nsteps = None

# Keeps track of the original slices that each client represent
slices = [None for o in node.outputs]
slices: list[None | list] = [None for o in node.outputs]

# A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate
@@ -1295,7 +1300,7 @@ def scan_save_mem(fgraph, node):
# or not
flag_store = False

# 2.2 Loop over the clients
# 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients
slices[i] = []
@@ -1338,7 +1343,7 @@ def scan_save_mem(fgraph, node):
except KeyError:
length = out.shape[0]
cf_slice = get_canonical_form_slice(this_slice[0], length)
slices[i] += [(cf_slice, this_slice)]
slices[i] += [(cf_slice, this_slice)] # type: ignore

if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None
@@ -1351,10 +1356,9 @@ def scan_save_mem(fgraph, node):
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1
)
if stop == maxsize or stop == get_scalar_constant_value(
length, raise_not_constant=False
):
if stop == get_scalar_constant_value(length, raise_not_constant=False):
stop = None
global_nsteps = None
else:
# there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output
@@ -1366,21 +1370,13 @@ def scan_save_mem(fgraph, node):
# initial state)
stop = stop - init_l[i]

# 2.3.3 we might get away with less number of steps
# 2.3.3 we might get away with fewer steps
if stop is not None and global_nsteps is not None:
# yes if it is a tensor
if isinstance(stop, Variable):
global_nsteps["sym"] += [stop]
# not if it is maxsize
elif isinstance(stop, int) and stop == maxsize:
global_nsteps = None
# yes if it is a int k, 0 < k < maxsize
elif isinstance(stop, int) and global_nsteps["real"] < stop:
global_nsteps["real"] = stop
# yes if it is a int k, 0 < k < maxsize
elif isinstance(stop, int) and stop > 0:
pass
# not otherwise
elif isinstance(stop, int | np.integer):
global_nsteps["real"] = max(global_nsteps["real"], stop)
else:
global_nsteps = None

@@ -1430,9 +1426,18 @@ def scan_save_mem(fgraph, node):
store_steps[i] = 0
break

if isinstance(this_slice[0], slice) and this_slice[0].start is None:
store_steps[i] = 0
break
if isinstance(this_slice[0], slice):
start = this_slice[0].start
if isinstance(start, Constant):
start = start.data
# Don't do anything if the subtensor starts from the beginning of the buffer,
# or if it just skips the initial values (the default output returned to the user).
# Trimming the initial values would require a roll to align the buffer once scan is done,
# as it always starts writing at position [0 + max(taps)] and ends up at position [:max(taps)].
# It's cheaper to just keep the initial values in the buffer and slice them away (default output).
if start in (0, None, init_l[i]):
store_steps[i] = 0
break

# Special case for recurrent outputs where only the last result
# is requested. This is needed for this rewrite to apply to
@@ -1478,7 +1483,10 @@ def scan_save_mem(fgraph, node):
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated.
prealloc_outs = config.scan__allow_output_prealloc
prealloc_outs = (
backend_supports_output_pre_allocation
and config.scan__allow_output_prealloc
)

first_mitsot_idx = op_info.n_mit_mot
last_sitsot_idx = (
@@ -1487,6 +1495,8 @@ def scan_save_mem(fgraph, node):
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx

if prealloc_outs and preallocable_output:
# TODO: If there's only one output or other outputs do not depend
# on the same input, we could reduce the buffer size to the minimum
pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
else:
pval = select_max(nw_steps - start + init_l[i], init_l[i])
@@ -1653,7 +1663,7 @@ def scan_save_mem(fgraph, node):
name=op.name,
allow_gc=op.allow_gc,
)
new_outs = new_op(*node_ins, return_list=True)
new_outs = cast(list[TensorVariable], new_op(*node_ins, return_list=True))

old_new = []
# 3.7 Get replace pairs for those outputs that do not change
@@ -1683,7 +1693,7 @@ def scan_save_mem(fgraph, node):
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = subtens(new_outs[nw_pos], *sl_ins)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx)
@@ -1703,10 +1713,7 @@ def scan_save_mem(fgraph, node):
- init_l[pos]
+ store_steps[pos]
)
if (
cnf_slice[0].stop is not None
and cnf_slice[0].stop != maxsize
):
if cnf_slice[0].stop is not None:
stop = (
cnf_slice[0].stop
- nw_steps
@@ -1741,7 +1748,7 @@ def scan_save_mem(fgraph, node):
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = subtens(new_outs[nw_pos], *sl_ins)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)]
@@ -1772,6 +1779,20 @@ def scan_save_mem(fgraph, node):
return False


@node_rewriter([Scan])
def scan_save_mem_prealloc(fgraph, node):
return scan_save_mem_rewrite(
fgraph, node, backend_supports_output_pre_allocation=True
)


@node_rewriter([Scan])
def scan_save_mem_no_prealloc(fgraph, node):
return scan_save_mem_rewrite(
fgraph, node, backend_supports_output_pre_allocation=False
)


class ScanMerge(GraphRewriter):
r"""Graph optimizer that merges different scan ops.

@@ -2499,10 +2520,20 @@ def scan_push_out_dot1(fgraph, node):
optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node.
optdb.register(
"scan_save_mem",
in2out(scan_save_mem, ignore_newtrees=True),
"scan_save_mem_prealloc",
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
"fast_run",
"scan",
"scan_save_mem",
position=1.61,
)
optdb.register(
"scan_save_mem_no_prealloc",
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
"numba",
"jax",
"pytorch",
use_db_name_as_tag=False,
position=1.61,
)
optdb.register(
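
As a closing illustration (not part of the diff), a hypothetical sketch of the scenario these registrations target: the same Scan graph compiled with the default C backend, which keeps the pre-allocation-aware save-memory rewrite, and with the Numba mode, which now runs scan_save_mem_no_prealloc instead:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Cumulative sum via Scan; only the last step is used, so the rewrite can shrink the output buffer.
outputs, _ = pytensor.scan(
    fn=lambda x_t, acc: acc + x_t,
    sequences=[x],
    outputs_info=[pt.zeros(())],
)
last = outputs[-1]

f_c = pytensor.function([x], last)                    # default mode: scan_save_mem_prealloc applies
f_numba = pytensor.function([x], last, mode="NUMBA")  # NUMBA mode: scan_save_mem_no_prealloc applies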