Apply scan memory save rewrite to while scans #178

Closed
@ricardoV94

Description

The rewrite save_mem_new_scan

def save_mem_new_scan(fgraph, node):

This rewrite seems to be applied to for-loop scans with either a static or a dynamic number of steps, but not to while-loop scans.

import pytensor
import pytensor.tensor as pt
from pytensor.scan import until

x = pt.scalar("x")
n_steps = pt.iscalar("n_steps")

y, _ = pytensor.scan(
    lambda xtm1: xtm1 + 1,  # for loop
    # lambda xtm1: (xtm1 + 1, {}, until(xtm1 >= 100)),  # while loop
    outputs_info=[x],
    n_steps=n_steps,  # dynamic
    # n_steps=100,  # static
    strict=True,
)
# Save memory is triggered by choosing only last value
y = y[-1]

pytensor.config.optimizer_verbose = True
f = pytensor.function([x, n_steps], y, on_unused_input="ignore")

Applying the rewrite can make a big difference in both memory use and performance, because it avoids allocating large arrays for outputs that are not of interest (see #174).
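To make the memory argument concrete, here is a minimal plain-Python sketch of the idea. It is not PyTensor's implementation: the function names (scan_full_trace, scan_last_only, while_last_only) and the limit parameter are hypothetical, and the while variant assumes the step that triggers the stop condition is still computed and kept. Each function returns the final value together with the number of stored states.

```python
def scan_full_trace(x0, n_steps):
    """Store every intermediate state, then read off only the last one."""
    trace = [0.0] * (n_steps + 1)  # buffer size grows with n_steps
    trace[0] = x0
    for t in range(n_steps):
        trace[t + 1] = trace[t] + 1
    return trace[-1], len(trace)


def scan_last_only(x0, n_steps):
    """Carry a single state forward; storage does not depend on n_steps."""
    state = x0
    for _ in range(n_steps):
        state = state + 1
    return state, 1


def while_last_only(x0, max_steps, limit=100):
    """While-loop analogue: stop once the previous state reaches `limit`.

    Assumption: the step on which the condition becomes true is still
    computed and kept as the final value.
    """
    state = x0
    for _ in range(max_steps):
        prev = state
        state = prev + 1
        if prev >= limit:
            break
    return state, 1
```

Both for-loop versions return the same final value, but only the first needs storage proportional to the number of steps. The while-loop version shows that, when only the last value is requested, a single carried state would in principle be enough there too.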
