import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

n = pt.iscalar("n")
x0 = pt.vector("x0")
# Scan that adds 1 to the previous state for n steps, starting from x0
xs, _ = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=n)

# Make sure the scan_save_mem rewrite is enabled when compiling
mode = get_default_mode().including("scan_save_mem")
fn = pytensor.function([n, x0], xs, mode=mode, on_unused_input="ignore")
fn.dprint()
Subtensor{start:} [id A] 6
 ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 5
 │  ├─ n [id C]
 │  └─ SetSubtensor{:stop} [id D] 4
 │     ├─ AllocEmpty{dtype='float64'} [id E] 3
 │     │  ├─ Add [id F] 2
 │     │  │  ├─ 1 [id G]
 │     │  │  └─ n [id C]
 │     │  └─ Shape_i{0} [id H] 1
 │     │     └─ x0 [id I]
 │     ├─ ExpandDims{axis=0} [id J] 0
 │     │  └─ x0 [id I]
 │     └─ 1 [id K]
 └─ 1 [id K]

Inner graphs:

Scan{scan_fn, while_loop=False, inplace=all} [id B]
 ← Add [id L]
    ├─ [1.] [id M]
    └─ *0-<Vector(float64, shape=(?,))> [id N] -> [id D]
I'd like to add an option, debugprint(print_inner_graphs=False), that omits the "Inner graphs" section from the output.
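Until such an option exists, one possible workaround is to capture the printout as text and cut it at the "Inner graphs:" line. This is a minimal sketch using a hypothetical helper, strip_inner_graphs (not part of PyTensor). It assumes the section header is always the literal line "Inner graphs:" and that everything after it belongs to that section. The sample string is an abbreviated excerpt of the output above.

```python
def strip_inner_graphs(dprint_text: str) -> str:
    """Drop the "Inner graphs:" section from debugprint-style text.

    Hypothetical helper: assumes the section starts at a line that is
    exactly "Inner graphs:" and runs to the end of the text.
    """
    kept = []
    for line in dprint_text.splitlines():
        if line.strip() == "Inner graphs:":
            break  # everything from here on belongs to the inner graphs
        kept.append(line)
    # Trim the blank line(s) that separated the outer graph from the section
    while kept and not kept[-1].strip():
        kept.pop()
    return "\n".join(kept)


# Abbreviated excerpt of the debugprint output shown above
sample = """Subtensor{start:} [id A] 6
 ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 5
 │  └─ n [id C]
 └─ 1 [id K]

Inner graphs:

Scan{scan_fn, while_loop=False, inplace=all} [id B]
 ← Add [id L]"""

print(strip_inner_graphs(sample))
```

With PyTensor, the text would presumably come from pytensor.printing.debugprint(fn, file="str"), which returns the printout as a string instead of writing it to stdout. The helper only post-processes that string, so it is not a substitute for a real print_inner_graphs flag, which could skip building the inner-graph section altogether.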