Skip to content

Fix scan save memory rewrite bug #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into the base branch from the feature branch
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,7 @@ def save_mem_new_scan(fgraph, node):
stop = at.extract_constant(cf_slice[0].stop)
else:
stop = at.extract_constant(cf_slice[0]) + 1
if stop == maxsize or stop == length:
if stop == maxsize or stop == at.extract_constant(length):
stop = None
else:
# there is a **gotcha** here ! Namely, scan returns an
Expand Down Expand Up @@ -1516,13 +1516,17 @@ def save_mem_new_scan(fgraph, node):
if (
nw_inputs[offset + idx].owner
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
and nw_inputs[offset + idx].owner.op.set_instead_of_inc
and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice
)
):
assert isinstance(
nw_inputs[offset + idx].owner.op, IncSubtensor
# Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
# As it happens in set_subtensor(empty(2)[:], 0)
and not (
nw_inputs[offset + idx].ndim
> nw_inputs[offset + idx].owner.inputs[1].ndim
)
):
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = at.as_tensor_variable(val)
initl = at.as_tensor_variable(init_l[i])
Expand Down
16 changes: 16 additions & 0 deletions tests/scan/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,22 @@ def test_while_scan_taps_and_map(self):
assert stored_ys_steps == 2
assert stored_zs_steps == 1

def test_vector_zeros_init(self):
    """Regression check: the save-memory rewrite must cope with an output
    whose initial state is a broadcast vector (``at.zeros(2)``) filled via
    ``set_subtensor`` with a scalar — see set_subtensor(empty(2)[:], 0).
    """
    # Two-tap recurrence seeded with a length-2 zero vector.
    outputs, _ = pytensor.scan(
        fn=lambda prev2, prev1: prev1 + prev2,
        outputs_info=[{"initial": at.zeros(2), "taps": range(-2, 0)}],
        n_steps=100,
    )

    # Only the last 50 steps are requested, so the rewrite may shrink storage.
    fn = pytensor.function([], outputs[-50:], mode=self.mode)
    assert tuple(fn().shape) == (50,)

    # Exactly one Scan node must remain in the compiled graph.
    [scan_node] = [n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan)]
    # The trace buffer fed to Scan should hold only the 50 kept steps,
    # proving the memory-saving rewrite actually fired.
    _, ys_trace = scan_node.inputs
    debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True)
    assert debug_fn() == 50


def test_inner_replace_dot():
"""
Expand Down