Description
The rewrite `save_mem_new_scan` (pytensor/pytensor/scan/rewriting.py, line 1119 in 8ad3317) seems to work for scan for-loops with both static and dynamic `n_steps`, but not for while loops:
```python
import pytensor
import pytensor.tensor as pt
from pytensor.scan import until

x = pt.scalar("x")
n_steps = pt.iscalar("n_steps")

y, _ = pytensor.scan(
    lambda xtm1: xtm1 + 1,  # for loop
    # lambda xtm1: (xtm1 + 1, {}, until(xtm1 >= 100)),  # while loop
    outputs_info=[x],
    n_steps=n_steps,  # dynamic
    # n_steps=100,  # static
    strict=True,
)

# Save memory is triggered by selecting only the last value
y = y[-1]

pytensor.config.optimizer_verbose = True
# on_unused_input="ignore" lets the same call compile when n_steps is static
f = pytensor.function([x, n_steps], y, on_unused_input="ignore")
```
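When comparing the compiled function with and without the rewrite, a plain-Python reference for the expected result can help confirm that the outputs do not change. This is only a sketch: `reference_last_value` and `stop_at` are hypothetical names, and the while-loop branch assumes that scan keeps the output of the step in which the `until` condition first holds and then stops.

```python
def reference_last_value(x, n_steps, stop_at=None):
    """Last value of the recurrence x_t = x_{t-1} + 1.

    stop_at=None mimics the for-loop variant. Otherwise the loop also ends
    after the step whose input x_{t-1} satisfies x_{t-1} >= stop_at
    (assumed to mirror `until(xtm1 >= 100)`).
    """
    # With zero steps the scan output is empty and y[-1] would fail.
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")
    xtm1 = x
    for _ in range(n_steps):
        xt = xtm1 + 1
        # While-loop variant: keep this step's output, then stop.
        if stop_at is not None and xtm1 >= stop_at:
            return xt
        xtm1 = xt
    return xtm1
```

For example, `reference_last_value(0.0, 10)` corresponds to `f(0.0, 10)` for the for-loop variant, and `reference_last_value(0.0, 1000, stop_at=100)` to the while-loop variant with a dynamic `n_steps` of 1000.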
This can make a large difference in both memory use and performance, because it avoids allocating large arrays for outputs whose intermediate values are never used (see #174).
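The memory difference can be illustrated outside PyTensor with NumPy. The sketch below is an analogy, not PyTensor's implementation: one function allocates a slot for every step (roughly what happens when the whole trace may be read), while the other keeps only the most recent values in a small circular buffer, which is the idea behind saving memory when only `y[-1]` is requested. The function names and the `buffer_size` parameter are hypothetical.

```python
import numpy as np


def last_value_full_history(x0, n_steps):
    # One float64 slot per step plus the initial value.
    trace = np.empty(n_steps + 1)
    trace[0] = x0
    for t in range(1, n_steps + 1):
        trace[t] = trace[t - 1] + 1
    return trace[-1], trace.nbytes


def last_value_rolling_buffer(x0, n_steps, buffer_size=1):
    # Only the `buffer_size` most recent values are stored; slots are reused
    # via modular indexing, so memory no longer grows with n_steps.
    buf = np.empty(buffer_size)
    buf[0] = x0
    for t in range(1, n_steps + 1):
        prev = buf[(t - 1) % buffer_size]
        buf[t % buffer_size] = prev + 1
    return buf[n_steps % buffer_size], buf.nbytes
```

Both functions return the same last value, but the first uses `8 * (n_steps + 1)` bytes for the trace and the second only `8 * buffer_size`. A buffer larger than one slot would be needed if a step read older values or if more than the last value were requested.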