@@ -339,39 +339,6 @@ def power_step(prior_result, x):
339
339
compare_numba_and_py ([A ], result , test_input_vals )
340
340
341
341
342
- @pytest .mark .parametrize ("n_steps_val" , [1 , 5 ])
343
- def test_scan_save_mem_basic (n_steps_val ):
344
- """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
345
-
346
- def f_pow2 (x_tm2 , x_tm1 ):
347
- return 2 * x_tm1 + x_tm2
348
-
349
- init_x = pt .dvector ("init_x" )
350
- n_steps = pt .iscalar ("n_steps" )
351
- output , _ = scan (
352
- f_pow2 ,
353
- sequences = [],
354
- outputs_info = [{"initial" : init_x , "taps" : [- 2 , - 1 ]}],
355
- non_sequences = [],
356
- n_steps = n_steps ,
357
- )
358
-
359
- state_val = np .array ([1.0 , 2.0 ])
360
-
361
- numba_mode = get_mode ("NUMBA" ).including ("scan_save_mem" )
362
- py_mode = Mode ("py" ).including ("scan_save_mem" )
363
-
364
- test_input_vals = (state_val , n_steps_val )
365
-
366
- compare_numba_and_py (
367
- [init_x , n_steps ],
368
- [output ],
369
- test_input_vals ,
370
- numba_mode = numba_mode ,
371
- py_mode = py_mode ,
372
- )
373
-
374
-
375
342
def test_grad_sitsot ():
376
343
def get_sum_of_grad (inp ):
377
344
scan_outputs , updates = scan (
@@ -482,3 +449,120 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
482
449
np .testing .assert_array_almost_equal (numba_r , ref_r )
483
450
484
451
benchmark (numba_fn , * test .values ())
452
+
453
+
454
+ @pytest .mark .parametrize (
455
+ "buffer_size" , ("unit" , "aligned" , "misaligned" , "whole" , "whole+init" )
456
+ )
457
+ @pytest .mark .parametrize ("n_steps, op_size" , [(10 , 2 ), (512 , 2 ), (512 , 256 )])
458
+ class TestScanSITSOTBuffer :
459
+ def buffer_tester (self , n_steps , op_size , buffer_size , benchmark = None ):
460
+ x0 = pt .vector (shape = (op_size ,), dtype = "float64" )
461
+ xs , _ = pytensor .scan (
462
+ fn = lambda xtm1 : (xtm1 + 1 ),
463
+ outputs_info = [x0 ],
464
+ n_steps = n_steps - 1 , # 1- makes it easier to align/misalign
465
+ )
466
+ if buffer_size == "unit" :
467
+ xs_kept = xs [- 1 ] # Only last state is used
468
+ expected_buffer_size = 2
469
+ elif buffer_size == "aligned" :
470
+ xs_kept = xs [- 2 :] # The buffer will be aligned at the end of the 9 steps
471
+ expected_buffer_size = 2
472
+ elif buffer_size == "misaligned" :
473
+ xs_kept = xs [- 3 :] # The buffer will be misaligned at the end of the 9 steps
474
+ expected_buffer_size = 3
475
+ elif buffer_size == "whole" :
476
+ xs_kept = xs # What users think is the whole buffer
477
+ expected_buffer_size = n_steps - 1
478
+ elif buffer_size == "whole+init" :
479
+ xs_kept = xs .owner .inputs [0 ] # Whole buffer actually used by Scan
480
+ expected_buffer_size = n_steps
481
+
482
+ x_test = np .zeros (x0 .type .shape )
483
+ numba_fn , _ = compare_numba_and_py (
484
+ [x0 ],
485
+ [xs_kept ],
486
+ test_inputs = [x_test ],
487
+ numba_mode = "NUMBA" , # Default doesn't include optimizations
488
+ eval_obj_mode = False ,
489
+ )
490
+ [scan_node ] = [
491
+ node
492
+ for node in numba_fn .maker .fgraph .toposort ()
493
+ if isinstance (node .op , Scan )
494
+ ]
495
+ buffer = scan_node .inputs [1 ]
496
+ assert buffer .type .shape [0 ] == expected_buffer_size
497
+
498
+ if benchmark is not None :
499
+ numba_fn .trust_input = True
500
+ benchmark (numba_fn , x_test )
501
+
502
+ def test_sit_sot_buffer (self , n_steps , op_size , buffer_size ):
503
+ self .buffer_tester (n_steps , op_size , buffer_size , benchmark = None )
504
+
505
+ def test_sit_sot_buffer_benchmark (self , n_steps , op_size , buffer_size , benchmark ):
506
+ self .buffer_tester (n_steps , op_size , buffer_size , benchmark = benchmark )
507
+
508
+
509
+ @pytest .mark .parametrize ("constant_n_steps" , [False , True ])
510
+ @pytest .mark .parametrize ("n_steps_val" , [1 , 1000 ])
511
+ class TestScanMITSOTBuffer :
512
+ def buffer_tester (self , constant_n_steps , n_steps_val , benchmark = None ):
513
+ """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
514
+
515
+ def f_pow2 (x_tm2 , x_tm1 ):
516
+ return 2 * x_tm1 + x_tm2
517
+
518
+ init_x = pt .vector ("init_x" , shape = (2 ,))
519
+ n_steps = pt .iscalar ("n_steps" )
520
+ output , _ = scan (
521
+ f_pow2 ,
522
+ sequences = [],
523
+ outputs_info = [{"initial" : init_x , "taps" : [- 2 , - 1 ]}],
524
+ non_sequences = [],
525
+ n_steps = n_steps_val if constant_n_steps else n_steps ,
526
+ )
527
+
528
+ init_x_val = np .array ([1.0 , 2.0 ], dtype = init_x .type .dtype )
529
+ test_vals = (
530
+ [init_x_val ]
531
+ if constant_n_steps
532
+ else [init_x_val , np .asarray (n_steps_val , dtype = n_steps .type .dtype )]
533
+ )
534
+ numba_fn , _ = compare_numba_and_py (
535
+ [init_x ] if constant_n_steps else [init_x , n_steps ],
536
+ [output [- 1 ]],
537
+ test_vals ,
538
+ numba_mode = "NUMBA" ,
539
+ eval_obj_mode = False ,
540
+ )
541
+
542
+ if n_steps_val == 1 and constant_n_steps :
543
+ # There's no Scan in the graph when nsteps=constant(1)
544
+ return
545
+
546
+ # Check the buffer size as been optimized
547
+ [scan_node ] = [
548
+ node
549
+ for node in numba_fn .maker .fgraph .toposort ()
550
+ if isinstance (node .op , Scan )
551
+ ]
552
+ [mitsot_buffer ] = scan_node .op .outer_mitsot (scan_node .inputs )
553
+ mitsot_buffer_shape = mitsot_buffer .shape .eval (
554
+ {init_x : init_x_val , n_steps : n_steps_val },
555
+ accept_inplace = True ,
556
+ on_unused_input = "ignore" ,
557
+ )
558
+ assert tuple (mitsot_buffer_shape ) == (3 ,)
559
+
560
+ if benchmark is not None :
561
+ numba_fn .trust_input = True
562
+ benchmark (numba_fn , * test_vals )
563
+
564
+ def test_mit_sot_buffer (self , constant_n_steps , n_steps_val ):
565
+ self .buffer_tester (constant_n_steps , n_steps_val , benchmark = None )
566
+
567
+ def test_mit_sot_buffer_benchmark (self , constant_n_steps , n_steps_val , benchmark ):
568
+ self .buffer_tester (constant_n_steps , n_steps_val , benchmark = benchmark )
0 commit comments