
Commit e37a8c0

Do not try to save initial values buffer size in Scan

This will always require a roll at the end, for a minimal gain.
1 parent 6135962 commit e37a8c0

3 files changed: +44 -33 lines

pytensor/scan/rewriting.py (12 additions & 3 deletions)

@@ -1419,9 +1419,18 @@ def scan_save_mem(fgraph, node):
                         store_steps[i] = 0
                         break
 
-                    if isinstance(this_slice[0], slice) and this_slice[0].start is None:
-                        store_steps[i] = 0
-                        break
+                    if isinstance(this_slice[0], slice):
+                        start = this_slice[0].start
+                        if isinstance(start, Constant):
+                            start = start.data
+                        # Don't do anything if the subtensor starts at the beginning of the buffer,
+                        # or just skips the initial values (the default output returned to the user).
+                        # Trimming the initial values would require a roll to align the buffer once
+                        # scan is done, since writing starts at position [0 + max(taps)] and ends at
+                        # position [:max(taps)]. Cheaper to keep them and slice away (default output).
+                        if start in (0, None, init_l[i]):
+                            store_steps[i] = 0
+                            break
 
            # Special case for recurrent outputs where only the last result
            # is requested. This is needed for this rewrite to apply to
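Why a roll would be needed: the new comment describes Scan's circular buffer, which starts writing at position max(taps), so the newest entries wrap around to the front. Below is a minimal NumPy sketch of that layout; it illustrates the idea only and is not PyTensor's actual implementation (store_steps, taps, and the recurrence are made up for the example):

import numpy as np

n_steps, taps = 6, 1         # sit-sot: a single initial value / tap
store_steps = n_steps        # buffer trimmed to exclude the initial value
buf = np.empty(store_steps)
x = 0.0                      # initial value kept outside the buffer
for t in range(n_steps):
    x = x + 1.0              # the recurrence: x_t = x_{t-1} + 1
    buf[(t + taps) % store_steps] = x  # writing starts at max(taps)

# The last `taps` results wrap around to positions [:max(taps)], so a
# roll is needed to return the buffer in chronological order:
assert (np.roll(buf, -taps) == np.arange(1.0, n_steps + 1)).all()

With store_steps = n_steps + taps (i.e. keeping the initial values in the buffer), the writes never wrap and no roll is needed, which is the trade-off this commit settles on.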

tests/link/numba/test_scan.py (1 addition & 1 deletion)

@@ -474,7 +474,7 @@ def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
             expected_buffer_size = 3
         elif buffer_size == "whole":
             xs_kept = xs  # What users think is the whole buffer
-            expected_buffer_size = n_steps - 1
+            expected_buffer_size = n_steps
         elif buffer_size == "whole+init":
             xs_kept = xs.owner.inputs[0]  # Whole buffer actually used by Scan
             expected_buffer_size = n_steps
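For context, the "whole" vs "whole+init" cases distinguish the user-facing output from the buffer Scan actually writes into. A hedged sketch of that relationship (the recurrence is illustrative; only pytensor.scan and the xs.owner.inputs[0] access mirror the test):

import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
# A simple sit-sot recurrence with a single initial value.
xs, _ = pytensor.scan(lambda x_tm1: x_tm1 + 1.0, outputs_info=[x0], n_steps=5)

whole = xs                            # "whole": what users see
whole_plus_init = xs.owner.inputs[0]  # "whole+init": the buffer Scan writes

f = pytensor.function([x0], whole)
assert f(0.0).shape == (5,)  # n_steps entries; the init is sliced away

After this commit, keeping the "whole" user-facing output keeps n_steps entries in the buffer instead of n_steps - 1, since trimming just the initial values is no longer attempted.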

tests/scan/test_printing.py (31 additions & 29 deletions)

@@ -643,35 +643,37 @@ def no_shared_fn(n, x_tm1, M):
     # (i.e. from `Scan._fn`)
     out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
 
-    expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
- ├─ 20000 [id B] (n_steps)
- ├─ [ 0 ... 998 19999] [id C] (outer_in_seqs-0)
- ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
- │  ├─ AllocEmpty{dtype='int64'} [id E] 0
- │  │  └─ 20000 [id B]
- │  ├─ [0] [id F]
- │  └─ 1 [id G]
- └─ <Tensor3(float64, shape=(20000, 2, 2))> [id H] (outer_in_non_seqs-0)
-
-Inner graphs:
-
-Scan{scan_fn, while_loop=False, inplace=all} [id A]
- ← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
-    └─ Subtensor{i, j, k} [id J]
-       ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
-       ├─ ScalarFromTensor [id L]
-       │  └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
-       ├─ ScalarFromTensor [id N]
-       │  └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
-       └─ 0 [id P]
-
-Composite{switch(lt(0, i0), 1, 0)} [id I]
- ← Switch [id Q] 'o0'
-    ├─ LT [id R]
-    │  ├─ 0 [id S]
-    │  └─ i0 [id T]
-    ├─ 1 [id U]
-    └─ 0 [id S]
+    expected_output = """Subtensor{start:} [id A] 3
+ ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 2 (outer_out_sit_sot-0)
+ │  ├─ 20000 [id C] (n_steps)
+ │  ├─ [ 0 ... 998 19999] [id D] (outer_in_seqs-0)
+ │  ├─ SetSubtensor{:stop} [id E] 1 (outer_in_sit_sot-0)
+ │  │  ├─ AllocEmpty{dtype='int64'} [id F] 0
+ │  │  │  └─ 20001 [id G]
+ │  │  ├─ [0] [id H]
+ │  │  └─ 1 [id I]
+ │  └─ <Tensor3(float64, shape=(20000, 2, 2))> [id J] (outer_in_non_seqs-0)
+ └─ 1 [id I]
+
+Inner graphs:
+
+Scan{scan_fn, while_loop=False, inplace=all} [id B]
+ ← Composite{switch(lt(0, i0), 1, 0)} [id K] (inner_out_sit_sot-0)
+    └─ Subtensor{i, j, k} [id L]
+       ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id M] -> [id J] (inner_in_non_seqs-0)
+       ├─ ScalarFromTensor [id N]
+       │  └─ *0-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_seqs-0)
+       ├─ ScalarFromTensor [id P]
+       │  └─ *1-<Scalar(int64, shape=())> [id Q] -> [id E] (inner_in_sit_sot-0)
+       └─ 0 [id R]
+
+Composite{switch(lt(0, i0), 1, 0)} [id K]
+ ← Switch [id S] 'o0'
+    ├─ LT [id T]
+    │  ├─ 0 [id U]
+    │  └─ i0 [id V]
+    ├─ 1 [id W]
+    └─ 0 [id U]
 """
 
     output_str = debugprint(out, file="str", print_op_info=True)
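A reading aid for the new expected output, in plain arithmetic (not part of the test): the AllocEmpty buffer grew by exactly the one initial value that is no longer trimmed, and the graph root is now the slice that hides it from the user.

n_steps = 20000   # [id C] in the printout
n_init = 1        # the single sit-sot initial value ([id H])
buffer_size = n_steps + n_init
assert buffer_size == 20001  # the AllocEmpty size, [id G]
# The user-facing output is buffer[n_init:], hence the root
# Subtensor{start:} [id A] with start = 1 ([id I]).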
