Skip to content

Commit 2957f7b

Browse files
committed
Fix bug in save_mem_new_scan due to broadcasting by set_subtensor
1 parent 05bed7b commit 2957f7b

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

pytensor/scan/rewriting.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1516,13 +1516,17 @@ def save_mem_new_scan(fgraph, node):
15161516
if (
15171517
nw_inputs[offset + idx].owner
15181518
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
1519+
and nw_inputs[offset + idx].owner.op.set_instead_of_inc
15191520
and isinstance(
15201521
nw_inputs[offset + idx].owner.op.idx_list[0], slice
15211522
)
1522-
):
1523-
assert isinstance(
1524-
nw_inputs[offset + idx].owner.op, IncSubtensor
1523+
# Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
1524+
# As it happens in set_subtensor(empty(2)[:], 0)
1525+
and not (
1526+
nw_inputs[offset + idx].ndim
1527+
> nw_inputs[offset + idx].owner.inputs[1].ndim
15251528
)
1529+
):
15261530
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
15271531
cval = at.as_tensor_variable(val)
15281532
initl = at.as_tensor_variable(init_l[i])

tests/scan/test_rewriting.py

+16
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,22 @@ def test_while_scan_taps_and_map(self):
14871487
assert stored_ys_steps == 2
14881488
assert stored_zs_steps == 1
14891489

1490+
def test_vector_zeros_init(self):
1491+
ys, _ = pytensor.scan(
1492+
fn=lambda ytm2, ytm1: ytm1 + ytm2,
1493+
outputs_info=[{"initial": at.zeros(2), "taps": range(-2, 0)}],
1494+
n_steps=100,
1495+
)
1496+
1497+
fn = pytensor.function([], ys[-50:], mode=self.mode)
1498+
assert tuple(fn().shape) == (50,)
1499+
1500+
# Check that rewrite worked
1501+
[scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
1502+
_, ys_trace = scan_node.inputs
1503+
debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True)
1504+
assert debug_fn() == 50
1505+
14901506

14911507
def test_inner_replace_dot():
14921508
"""

0 commit comments

Comments
 (0)