from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until
from pytensor.tensor import stack
+from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ def test_inplace3(self):


class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")

    def test_save_mem(self):
        rng = np.random.default_rng(utt.fetch_seed())
@@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
        )

    def test_save_mem_store_steps(self):
-        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
+        def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
            return (
                u_t + 1.0,
                u_t + 2.0,
@@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
        x30 = vector("x30")
        x40 = scalar("x40")
        [x1, x2, x3, x4, x5, x6, x7], updates = scan(
-            f_rnn,
+            step,
            u,
            [
                None,
@@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
            go_backwards=False,
        )

-        f2 = function(
+        f = function(
            [u, x10, x20, x30, x40],
            [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
            updates=updates,
@@ -1417,13 +1418,51 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
        v_u = rng.uniform(-5.0, 5.0, size=(20,))

        # compute the output in numpy
-        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
-
-        utt.assert_allclose(tx1, v_u[-7] + 1.0)
-        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
-        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
-        utt.assert_allclose(tx4, v_u[-1] + 4.0)
-        utt.assert_allclose(tx5, v_u[-1] + 5.0)
+        tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
+        rtol = 1e-7 if config.floatX == "float64" else 1e-6
+        np.testing.assert_allclose(tx1, v_u[-7] + 1.0, rtol=rtol)
+        np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0, rtol=rtol)
+        np.testing.assert_allclose(tx3, v_u[-6:] + 3.0, rtol=rtol)
+        np.testing.assert_allclose(tx4, v_u[-1] + 4.0, rtol=rtol)
+        np.testing.assert_allclose(tx5, v_u[-1] + 5.0, rtol=rtol)
+
+        # Confirm reduction in buffer sizes
+        [scan_node] = [
+            node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+        ]
+        # x6 and x7 are dropped because they are not used
+        [n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
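+        # The surviving Scan inputs are the step count, the sequence, the two
+        # recurrent tap buffers, and the requested lengths of the map outputs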
+        [x4_underlying_alloc] = [
+            var
+            for var in ancestors([x4_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        [x5_underlying_alloc] = [
+            var
+            for var in ancestors([x5_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
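+        # Each tap buffer is built on top of an AllocEmpty allocation; the
+        # leading dimension of that allocation is the true storage size that
+        # ScanSaveMem left for the output, checked below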
+        buffer_lengths = pytensor.function(
+            [u, x10, x20, x30, x40],
+            [
+                x1_len,
+                x2_len,
+                x3_len,
+                x4_underlying_alloc.shape[0],
+                x5_underlying_alloc.shape[0],
+            ],
+            accept_inplace=True,
+            on_unused_input="ignore",
+            allow_input_downcast=True,
+        )(v_u, [0, 0], 0, [0, 0], 0)
+        # ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
+        assert [int(i) for i in buffer_lengths] == [
+            7,  # entry -7 of a map variable is kept, we need at least that many
+            3,  # entries [-3, -2] of a map variable are kept, we need at least 3
+            6,  # last six entries of a map variable are kept
+            2 + 1,  # last entry of a double tap variable is kept
+            1 + 1,  # last entry of a single tap variable is kept
+        ]

    def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
        var = pt.ones(())