Skip to content

Commit e1f77bf

Browse files
committed
Fix constant number of steps reduction in ScanSaveMem rewrite
Also remove maxsize logic
1 parent 4d9c1c5 commit e1f77bf

File tree

2 files changed

+42
-22
lines changed

2 files changed

+42
-22
lines changed

pytensor/scan/rewriting.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import copy
44
import dataclasses
55
from itertools import chain
6-
from sys import maxsize
76
from typing import cast
87

98
import numpy as np
@@ -1351,9 +1350,7 @@ def scan_save_mem(fgraph, node):
13511350
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
13521351
+ 1
13531352
)
1354-
if stop == maxsize or stop == get_scalar_constant_value(
1355-
length, raise_not_constant=False
1356-
):
1353+
if stop == get_scalar_constant_value(length, raise_not_constant=False):
13571354
stop = None
13581355
else:
13591356
# there is a **gotcha** here! Namely, scan returns an
@@ -1366,21 +1363,13 @@ def scan_save_mem(fgraph, node):
13661363
# initial state)
13671364
stop = stop - init_l[i]
13681365

1369-
# 2.3.3 we might get away with less number of steps
1366+
# 2.3.3 we might get away with fewer steps
13701367
if stop is not None and global_nsteps is not None:
13711368
# yes if it is a tensor
13721369
if isinstance(stop, Variable):
13731370
global_nsteps["sym"] += [stop]
1374-
# not if it is maxsize
1375-
elif isinstance(stop, int) and stop == maxsize:
1376-
global_nsteps = None
1377-
# yes if it is a int k, 0 < k < maxsize
1378-
elif isinstance(stop, int) and global_nsteps["real"] < stop:
1379-
global_nsteps["real"] = stop
1380-
# yes if it is a int k, 0 < k < maxsize
1381-
elif isinstance(stop, int) and stop > 0:
1382-
pass
1383-
# not otherwise
1371+
elif isinstance(stop, int | np.integer):
1372+
global_nsteps["real"] = max(global_nsteps["real"], stop)
13841373
else:
13851374
global_nsteps = None
13861375

@@ -1703,10 +1692,7 @@ def scan_save_mem(fgraph, node):
17031692
- init_l[pos]
17041693
+ store_steps[pos]
17051694
)
1706-
if (
1707-
cnf_slice[0].stop is not None
1708-
and cnf_slice[0].stop != maxsize
1709-
):
1695+
if cnf_slice[0].stop is not None:
17101696
stop = (
17111697
cnf_slice[0].stop
17121698
- nw_steps

tests/scan/test_rewriting.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.compile.mode import get_default_mode
1010
from pytensor.configdefaults import config
1111
from pytensor.gradient import grad, jacobian
12-
from pytensor.graph.basic import equal_computations
12+
from pytensor.graph.basic import Constant, equal_computations
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.replace import clone_replace
1515
from pytensor.scan.op import Scan
@@ -1208,7 +1208,7 @@ def test_inplace3(self):
12081208

12091209

12101210
class TestSaveMem:
1211-
mode = get_default_mode().including("scan_save_mem", "scan_save_mem")
1211+
mode = get_default_mode().including("scan_save_mem")
12121212

12131213
def test_save_mem(self):
12141214
rng = np.random.default_rng(utt.fetch_seed())
@@ -1295,12 +1295,26 @@ def f_rnn(u_t):
12951295
[x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
12961296
updates=updates,
12971297
allow_input_downcast=True,
1298-
mode=self.mode,
1298+
mode=self.mode.excluding("scan_push_out_seq"),
12991299
)
1300+
# Check we actually have a Scan in the compiled function
1301+
[scan_node] = [
1302+
node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Scan)
1303+
]
1304+
13001305
# get random initial values
13011306
rng = np.random.default_rng(utt.fetch_seed())
13021307
v_u = rng.uniform(-5.0, 5.0, size=(20,))
13031308

1309+
# Check the number of steps is actually reduced from 20
1310+
n_steps = scan_node.inputs[0]
1311+
n_steps_fn = pytensor.function([u, idx, jdx], n_steps, accept_inplace=True)
1312+
assert n_steps_fn(u=v_u, idx=3, jdx=15) == 11 # x5[const=-10] requires 11 steps
1313+
assert n_steps_fn(u=v_u, idx=3, jdx=3) == 18 # x6[jdx=-3] requires 18 steps
1314+
assert n_steps_fn(u=v_u, idx=16, jdx=15) == 17 # x3[idx=16] requires 17 steps
1315+
assert n_steps_fn(u=v_u, idx=-5, jdx=15) == 16 # x3[idx=-5] requires 16 steps
1316+
assert n_steps_fn(u=v_u, idx=19, jdx=15) == 20 # x3[idx=19] requires 20 steps
1317+
13041318
# compute the output in numpy
13051319
tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15)
13061320

@@ -1312,6 +1326,26 @@ def f_rnn(u_t):
13121326
utt.assert_allclose(tx6, v_u[-15] + 6.0)
13131327
utt.assert_allclose(tx7, v_u[:-15] + 7.0)
13141328

1329+
def test_save_mem_reducent_number_of_steps_constant(self):
1330+
x0 = pt.scalar("x0")
1331+
xs, _ = scan(
1332+
lambda xtm1: xtm1 + 1,
1333+
outputs_info=[x0],
1334+
n_steps=10,
1335+
)
1336+
1337+
fn = function([x0], xs[:5])
1338+
[scan_node] = [
1339+
node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
1340+
]
1341+
n_steps = scan_node.inputs[0]
1342+
assert isinstance(n_steps, Constant) and n_steps.data == 5
1343+
1344+
np.testing.assert_allclose(
1345+
fn(0),
1346+
np.arange(1, 11)[:5],
1347+
)
1348+
13151349
def test_save_mem_store_steps(self):
13161350
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
13171351
return (

0 commit comments

Comments
 (0)