 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import Constant, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
 from pytensor.scan.op import Scan
@@ -1208,7 +1208,7 @@ def test_inplace3(self):
 
 
 class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem", "scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem")
 
     def test_save_mem(self):
         rng = np.random.default_rng(utt.fetch_seed())
@@ -1295,12 +1295,26 @@ def f_rnn(u_t):
             [x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
             updates=updates,
             allow_input_downcast=True,
-            mode=self.mode,
+            mode=self.mode.excluding("scan_push_out_seq"),
         )
+        # Check we actually have a Scan in the compiled function
+        [scan_node] = [
+            node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Scan)
+        ]
+
         # get random initial values
         rng = np.random.default_rng(utt.fetch_seed())
         v_u = rng.uniform(-5.0, 5.0, size=(20,))
 
+        # Check the number of steps is actually reduced from 20
+        n_steps = scan_node.inputs[0]
+        n_steps_fn = pytensor.function([u, idx, jdx], n_steps, accept_inplace=True)
+        assert n_steps_fn(u=v_u, idx=3, jdx=15) == 11  # x5[const=-10] requires 11 steps
+        assert n_steps_fn(u=v_u, idx=3, jdx=3) == 18  # x6[jdx=-3] requires 18 steps
+        assert n_steps_fn(u=v_u, idx=16, jdx=15) == 17  # x3[idx=16] requires 17 steps
+        assert n_steps_fn(u=v_u, idx=-5, jdx=15) == 16  # x3[idx=-5] requires 16 steps
+        assert n_steps_fn(u=v_u, idx=19, jdx=15) == 20  # x3[idx=19] requires 20 steps
+
         # compute the output in numpy
         tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15)
@@ -1312,6 +1326,26 @@ def f_rnn(u_t):
         utt.assert_allclose(tx6, v_u[-15] + 6.0)
         utt.assert_allclose(tx7, v_u[:-15] + 7.0)
 
+    def test_save_mem_reduced_number_of_steps_constant(self):
+        x0 = pt.scalar("x0")
+        xs, _ = scan(
+            lambda xtm1: xtm1 + 1,
+            outputs_info=[x0],
+            n_steps=10,
+        )
+
+        fn = function([x0], xs[:5])
+        [scan_node] = [
+            node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
+        ]
+        n_steps = scan_node.inputs[0]
+        assert isinstance(n_steps, Constant) and n_steps.data == 5
+
+        np.testing.assert_allclose(
+            fn(0),
+            np.arange(1, 11)[:5],
+        )
+
     def test_save_mem_store_steps(self):
         def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             return (