Avoid default allocation for taps of length 1 in ScanSaveMem #1395

Merged (1 commit) on May 30, 2025
39 changes: 24 additions & 15 deletions pytensor/scan/rewriting.py
@@ -53,6 +53,7 @@
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
atleast_Nd,
get_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements


def _is_default_scan_buffer(x: TensorVariable) -> bool:
node = x.owner
def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
node = final_buffer.owner

if node is None:
return False
@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
):
return False

x, y, *_ = node.inputs
if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
init_buffer, init_value, *_ = node.inputs
if not (
Review comment (Member): Is this the default buffer case? (just uninitialized)

Review comment (ricardoV94, Member Author, May 30, 2025): The default buffer is an AllocEmpty with enough space to hold all scan steps plus the initial taps; the initial taps are written to the beginning of the buffer with set_subtensor.
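A minimal sketch of the construction described in that comment, assuming a single recurrent state with two initial taps over n_steps iterations (names and shapes here are illustrative, not the exact graph Scan builds):

import pytensor.tensor as pt

n_steps = 10
init = pt.vector("init")  # the two initial taps, shape (2,)

# An uninitialized buffer (AllocEmpty) large enough for the initial taps
# plus every scan step, with the taps written into its first two entries.
buffer = pt.empty((2 + n_steps,))
buffer = pt.set_subtensor(buffer[:2], init)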

Review comment (Member): So is this the default case?

Review comment (ricardoV94, Member Author): I don't understand the question: if it's not an AllocEmpty, it's not a default buffer, hence why it returns False. Am I missing something?

init_buffer.owner is not None and isinstance(init_buffer.owner.op, AllocEmpty)
):
return False

# The value may have been broadcast to fill in the initial taps.
@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
# 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
# But due to laziness we use the slightly more conservative check:
# 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
if broadcasted_by(y, x):
return False

return True
if taps > 1:
return not broadcasted_by(init_value, init_buffer)
else:
# In this case we know we have alloc_empty(1 + nsteps, ...)[:1].set(init_value)
# The first dimension cannot possibly broadcast in the subtensor assignment,
# so we exclude it from `broadcasted_by`. To exclude it we squeeze it out,
# after adding any other implicit expand_dims. We select into the first entry of
# the buffer, to check for potential broadcasting in other dimensions.
init_value_ = atleast_Nd(init_value, n=init_buffer.ndim)
return not broadcasted_by(init_value_.squeeze(0), init_buffer[0])
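To make the single-tap branch concrete, a small sketch under the assumption of a length-5 vector state (not part of the diff; broadcasted_by comes from the module edited below):

import pytensor.tensor as pt
from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.rewriting.basic import broadcasted_by

n_steps = 10
buffer = pt.empty((1 + n_steps, 5))

# Initial value whose trailing dims match the buffer: still a default buffer.
init_value = pt.vector("init_value", shape=(5,))
init_value_ = atleast_Nd(init_value, n=buffer.ndim)  # shape (1, 5)
assert not broadcasted_by(init_value_.squeeze(0), buffer[0])

# Initial value that would be broadcast along the non-time dimension: not a default buffer.
init_bcast = pt.vector("init_bcast", shape=(1,))
init_bcast_ = atleast_Nd(init_bcast, n=buffer.ndim)  # shape (1, 1)
assert broadcasted_by(init_bcast_.squeeze(0), buffer[0])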


def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
taps = init_l[i]
nw_input = nw_inputs[offset + idx]

# Recreate default buffers with new size
if _is_default_scan_buffer(nw_input):
extra_size = 1 if required_orphan else val - init_l[i]
if _is_default_scan_buffer(nw_input, taps):
extra_size = 1 if required_orphan else val - taps
nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
# Otherwise, just trim with a slice
else:
stop = init_l[i] if required_orphan else val
stop = taps if required_orphan else val
nw_input = nw_input[:stop]

nw_inputs[offset + idx] = nw_input
@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
taps = init_l[op_info.n_mit_mot + idx]
in_idx = offset + idx
nw_input = nw_inputs[in_idx]
if _is_default_scan_buffer(nw_input):
if _is_default_scan_buffer(nw_input, taps):
nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
else:
# Number of steps in the initial state
init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
nw_input = nw_input[: (init_l_pt + nw_steps)]
nw_input = nw_input[: (taps + nw_steps)]
nw_inputs[in_idx] = nw_input

elif (
6 changes: 4 additions & 2 deletions pytensor/tensor/rewriting/basic.py
@@ -95,9 +95,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
"""
bx = x.type.broadcastable
by = y.type.broadcastable
if len(bx) < len(by):
bx_len = len(bx)
by_len = len(by)
if bx_len < by_len:
return True
bx = bx[-len(by) :]
bx = bx[bx_len - by_len :]
Review comment (ricardoV94, Member Author): this would fail with the infamous [-0:] edge case (referring to the previous bx[-len(by) :] slice)

return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True))
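A quick plain-Python illustration of the [-0:] edge case mentioned in the comment above, with tuples standing in for type.broadcastable:

bx = (False, True, False)
by = ()  # y is a 0-d tensor, so len(by) == 0

# Old slice: -len(by) is -0, and bx[-0:] means bx[0:], i.e. the whole tuple.
assert bx[-len(by):] == (False, True, False)

# Fixed slice keeps only the trailing len(by) entries, here none.
assert bx[len(bx) - len(by):] == ()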


63 changes: 51 additions & 12 deletions tests/scan/test_rewriting.py
@@ -9,13 +9,14 @@
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.basic import Constant, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until
from pytensor.tensor import stack
from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ def test_inplace3(self):


class TestSaveMem:
mode = get_default_mode().including("scan_save_mem")
mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")

def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed())
@@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
)

def test_save_mem_store_steps(self):
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
return (
u_t + 1.0,
u_t + 2.0,
@@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
x30 = vector("x30")
x40 = scalar("x40")
[x1, x2, x3, x4, x5, x6, x7], updates = scan(
f_rnn,
step,
u,
[
None,
@@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
go_backwards=False,
)

f2 = function(
f = function(
[u, x10, x20, x30, x40],
[x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
updates=updates,
@@ -1417,13 +1418,51 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
v_u = rng.uniform(-5.0, 5.0, size=(20,))

# compute the output in numpy
tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)

utt.assert_allclose(tx1, v_u[-7] + 1.0)
utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
utt.assert_allclose(tx3, v_u[-6:] + 3.0)
utt.assert_allclose(tx4, v_u[-1] + 4.0)
utt.assert_allclose(tx5, v_u[-1] + 5.0)
tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
rtol = 1e-7 if config.floatX == "float64" else 1e-6
np.testing.assert_allclose(tx1, v_u[-7] + 1.0, rtol=rtol)
np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0, rtol=rtol)
np.testing.assert_allclose(tx3, v_u[-6:] + 3.0, rtol=rtol)
np.testing.assert_allclose(tx4, v_u[-1] + 4.0, rtol=rtol)
np.testing.assert_allclose(tx5, v_u[-1] + 5.0, rtol=rtol)

# Confirm reduction in buffer sizes
[scan_node] = [
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
# x6 and x7 are dropped because they are not used
[n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
[x4_underlying_alloc] = [
var
for var in ancestors([x4_buffer])
if var.owner and isinstance(var.owner.op, AllocEmpty)
]
[x5_underlying_alloc] = [
var
for var in ancestors([x5_buffer])
if var.owner and isinstance(var.owner.op, AllocEmpty)
]
buffer_lengths = pytensor.function(
[u, x10, x20, x30, x40],
[
x1_len,
x2_len,
x3_len,
x4_underlying_alloc.shape[0],
x5_underlying_alloc.shape[0],
],
accept_inplace=True,
on_unused_input="ignore",
allow_input_downcast=True,
)(v_u, [0, 0], 0, [0, 0], 0)
# ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
assert [int(i) for i in buffer_lengths] == [
7, # entry -7 of a map variable is kept, we need at least that many
3, # entries [-3, -2] of a map variable are kept, we need at least 3
6, # last six entries of a map variable are kept
2 + 1, # last entry of a double tap variable is kept
1 + 1, # last entry of a single tap variable is kept
]

def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = pt.ones(())