
Commit cb61b0c

Avoid large allocation for taps of length 1 in ScanSaveMem
1 parent 2d414d4 commit cb61b0c

File tree

3 files changed, +79 -29 lines changed

pytensor/scan/rewriting.py
pytensor/tensor/rewriting/basic.py
tests/scan/test_rewriting.py


pytensor/scan/rewriting.py

Lines changed: 24 additions & 15 deletions
@@ -53,6 +53,7 @@
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
+    atleast_Nd,
     get_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements


-def _is_default_scan_buffer(x: TensorVariable) -> bool:
-    node = x.owner
+def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
+    node = final_buffer.owner

     if node is None:
         return False
@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     ):
         return False

-    x, y, *_ = node.inputs
-    if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+    init_buffer, init_value, *_ = node.inputs
+    if not (
+        init_buffer.owner is not None and isinstance(init_buffer.owner.op, AllocEmpty)
+    ):
         return False

     # The value may have been broadcast to fill in the initial taps.
@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     # 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
     # But due to laziness we use the slightly more conservative check:
     # 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
-    if broadcasted_by(y, x):
-        return False
-
-    return True
+    if taps > 1:
+        return not broadcasted_by(init_value, init_buffer)
+    else:
+        # In this case we know we have alloc_empty(1 + nsteps, ...)[:1].set(init_value)
+        # The first dimension cannot possibly broadcast in the subtensor assignment,
+        # so we exclude it from `broadcasted_by`. To exclude it we squeeze it out,
+        # after adding any other implicit expand_dims. We select into the first entry of
+        # the buffer, to check for potential broadcasting in other dimensions.
+        init_value_ = atleast_Nd(init_value, n=init_buffer.ndim)
+        return not broadcasted_by(init_value_.squeeze(0), init_buffer[0])


 def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[i]
                     nw_input = nw_inputs[offset + idx]

                     # Recreate default buffers with new size
-                    if _is_default_scan_buffer(nw_input):
-                        extra_size = 1 if required_orphan else val - init_l[i]
+                    if _is_default_scan_buffer(nw_input, taps):
+                        extra_size = 1 if required_orphan else val - taps
                         nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
                     # Otherwise, just trim with a slice
                     else:
-                        stop = init_l[i] if required_orphan else val
+                        stop = taps if required_orphan else val
                         nw_input = nw_input[:stop]

                     nw_inputs[offset + idx] = nw_input
@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # val == 0 means that we want to keep all intermediate
                 # results for that state, including the initial values.
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[op_info.n_mit_mot + idx]
                     in_idx = offset + idx
                     nw_input = nw_inputs[in_idx]
-                    if _is_default_scan_buffer(nw_input):
+                    if _is_default_scan_buffer(nw_input, taps):
                         nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                     else:
-                        # Number of steps in the initial state
-                        init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
-                        nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_input = nw_input[: (taps + nw_steps)]
                     nw_inputs[in_idx] = nw_input

                 elif (
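The taps == 1 branch above is the heart of the change. A minimal sketch of what it distinguishes, using only the two helpers referenced in the diff (broadcasted_by and atleast_Nd); the shapes and the pt.zeros stand-in for the alloc_empty buffer are hypothetical, since only the broadcastable patterns matter for the check:

import pytensor.tensor as pt
from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.rewriting.basic import broadcasted_by

# Hypothetical single-tap (sit-sot) layout: the default buffer is
# alloc_empty(1 + nsteps, 5)[:1].set(init_value). pt.zeros only stands in for
# alloc_empty here; it has the same static shape / broadcastable pattern.
init_value = pt.vector("init_value")
init_buffer = pt.zeros((1 + 10, 5))

# Old check: the init value has one dimension fewer than the buffer, so it
# looks "broadcasted by" the buffer and the default-buffer path was rejected.
print(broadcasted_by(init_value, init_buffer))  # True -> old code returned False

# New check for taps == 1: align ndim, squeeze out the leading dim (it cannot
# broadcast inside the set_subtensor), and compare against a single buffer row.
init_value_ = atleast_Nd(init_value, n=init_buffer.ndim)
print(broadcasted_by(init_value_.squeeze(0), init_buffer[0]))  # False -> default buffer

With the old check, _is_default_scan_buffer returned False for this layout, so ScanSaveMem fell back to slicing the original alloc_empty(1 + nsteps, ...) buffer and the full-length allocation survived; the new branch lets it rebuild a buffer of just taps + 1 rows instead.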

pytensor/tensor/rewriting/basic.py

Lines changed: 4 additions & 2 deletions
@@ -95,9 +95,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
     """
     bx = x.type.broadcastable
     by = y.type.broadcastable
-    if len(bx) < len(by):
+    bx_len = len(bx)
+    by_len = len(by)
+    if bx_len < by_len:
         return True
-    bx = bx[-len(by) :]
+    bx = bx[bx_len - by_len :]
     return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True))
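The slice rewrite only changes behaviour when y has no dimensions: -len(by) is then -0, so the old slice kept all of bx instead of the empty suffix, and the strict=True zip that follows would raise on the length mismatch. A plain-Python sketch of the difference, with a hypothetical broadcastable pattern:

bx = (False, True, False)  # hypothetical x.type.broadcastable for a 3d x
by = ()                    # y is 0d, so nothing in x should be aligned against it

print(bx[-len(by):])           # (False, True, False): -0 slices from the start
print(bx[len(bx) - len(by):])  # (): the rewritten slice keeps the correct (empty) suffix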

tests/scan/test_rewriting.py

Lines changed: 51 additions & 12 deletions
@@ -9,13 +9,14 @@
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
 from pytensor.scan.utils import until
 from pytensor.tensor import stack
+from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ def test_inplace3(self):


 class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")

     def test_save_mem(self):
         rng = np.random.default_rng(utt.fetch_seed())
@@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
         )

     def test_save_mem_store_steps(self):
-        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
+        def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             return (
                 u_t + 1.0,
                 u_t + 2.0,
@@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         x30 = vector("x30")
         x40 = scalar("x40")
         [x1, x2, x3, x4, x5, x6, x7], updates = scan(
-            f_rnn,
+            step,
             u,
             [
                 None,
@@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             go_backwards=False,
         )

-        f2 = function(
+        f = function(
             [u, x10, x20, x30, x40],
             [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
             updates=updates,
@@ -1417,13 +1418,51 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         v_u = rng.uniform(-5.0, 5.0, size=(20,))

         # compute the output in numpy
-        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
-
-        utt.assert_allclose(tx1, v_u[-7] + 1.0)
-        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
-        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
-        utt.assert_allclose(tx4, v_u[-1] + 4.0)
-        utt.assert_allclose(tx5, v_u[-1] + 5.0)
+        tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
+        rtol = 1e-7 if config.floatX == "float64" else 1e-6
+        np.testing.assert_allclose(tx1, v_u[-7] + 1.0, rtol=rtol)
+        np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0, rtol=rtol)
+        np.testing.assert_allclose(tx3, v_u[-6:] + 3.0, rtol=rtol)
+        np.testing.assert_allclose(tx4, v_u[-1] + 4.0, rtol=rtol)
+        np.testing.assert_allclose(tx5, v_u[-1] + 5.0, rtol=rtol)
+
+        # Confirm reduction in buffer sizes
+        [scan_node] = [
+            node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+        ]
+        # x6 and x7 are dropped because they are not used
+        [n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
+        [x4_underlying_alloc] = [
+            var
+            for var in ancestors([x4_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        [x5_underlying_alloc] = [
+            var
+            for var in ancestors([x5_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        buffer_lengths = pytensor.function(
+            [u, x10, x20, x30, x40],
+            [
+                x1_len,
+                x2_len,
+                x3_len,
+                x4_underlying_alloc.shape[0],
+                x5_underlying_alloc.shape[0],
+            ],
+            accept_inplace=True,
+            on_unused_input="ignore",
+            allow_input_downcast=True,
+        )(v_u, [0, 0], 0, [0, 0], 0)
+        # ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
+        assert [int(i) for i in buffer_lengths] == [
+            7,  # entry -7 of a map variable is kept, we need at least that many
+            3,  # entries [-3, -2] of a map variable are kept, we need at least 3
+            6,  # last six entries of a map variable are kept
+            2 + 1,  # last entry of a double tap variable is kept
+            1 + 1,  # last entry of a single tap variable is kept
+        ]

     def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
         var = pt.ones(())
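The new assertions pin down the buffer sizes ScanSaveMem leaves behind. A minimal sketch, outside the test suite, of the single-tap case named in the commit title; the recurrence and n_steps are hypothetical, and the expected size assumes the default compilation mode applies scan_save_mem:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import ancestors
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.tensor.basic import AllocEmpty

# Single-tap (sit-sot) recurrence where only the last value is requested.
x0 = pt.scalar("x0")
xs, _ = scan(lambda x_tm1: x_tm1 + 1.0, outputs_info=[x0], n_steps=100)
f = pytensor.function([x0], xs[-1])

[scan_node] = [n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)]
allocs = [
    v
    for v in ancestors(scan_node.inputs)
    if v.owner and isinstance(v.owner.op, AllocEmpty)
]
# Before this commit the recurrent state kept an alloc_empty of 1 + n_steps rows;
# afterwards its shape input is expected to be the small constant taps + 1 = 2.
for alloc in allocs:
    print(alloc.owner.inputs)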
