Skip to content

Commit 7bcd42c

Browse files
committed
Keep outer graph visible to the scan user function, including sequences
Sequences are now demoted to being just another constant in the Scan Op. The user facing function creates the right indexing graph for iterating over sequences automatically. Some extra logic is added in the `scan_to_loop` rewrite to avoid creating duplicated indexes, while being on guard for Scans created elsewhere.
1 parent c46cd53 commit 7bcd42c

File tree

4 files changed

+197
-147
lines changed

4 files changed

+197
-147
lines changed

pytensor/loop/basic.py

+47-16
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import functools
12
from typing import List, Tuple
23

34
import numpy as np
45

5-
from pytensor import Variable, as_symbolic
6+
from pytensor import Variable, as_symbolic, clone_replace
67
from pytensor.graph import FunctionGraph
8+
from pytensor.graph.basic import Constant, truncated_graph_inputs
79
from pytensor.loop.op import Scan
810
from pytensor.scan.utils import until
9-
from pytensor.tensor import as_tensor, empty_like
11+
from pytensor.tensor import as_tensor, constant, empty_like, minimum
1012

1113

1214
def scan(
@@ -20,6 +22,8 @@ def scan(
2022
if sequences is None and n_steps is None:
2123
raise ValueError("Must provide n_steps when scanning without sequences")
2224

25+
# TODO: init_states should be made opaque to the inner function,
26+
# since any relationship to the outer graph no longer holds
2327
if init_states is None:
2428
init_states = []
2529
else:
@@ -34,20 +38,31 @@ def scan(
3438
sequences = [sequences]
3539
sequences = [as_tensor(s) for s in sequences]
3640

41+
if sequences:
42+
leading_dims = [seq.shape[0] for seq in sequences]
43+
shortest_dim = functools.reduce(minimum, leading_dims)
44+
if n_steps is None:
45+
n_steps = shortest_dim
46+
else:
47+
n_steps = minimum(n_steps, shortest_dim)
48+
3749
if non_sequences is None:
3850
non_sequences = []
3951
else:
4052
if not isinstance(non_sequences, (tuple, list)):
4153
non_sequences = [non_sequences]
4254
non_sequences = [as_symbolic(n) for n in non_sequences]
4355

56+
# Create subsequence inputs for the inner function
57+
idx = constant(0, dtype="int64", name="idx")
58+
symbolic_idx = idx.type(name="idx")
59+
subsequences = [s[symbolic_idx] for s in sequences]
4460
# Note: Old scan order is sequences + init + non_sequences
45-
inner_sequences = [s[0] for s in sequences]
46-
inner_inputs = [i.type() for i in init_states + inner_sequences + non_sequences]
47-
inner_outputs = fn(*inner_inputs)
48-
if not isinstance(inner_outputs, (tuple, list)):
49-
inner_outputs = [inner_outputs]
50-
next_states = [out for out in inner_outputs if not isinstance(out, until)]
61+
fn_inputs = init_states + subsequences + non_sequences
62+
fn_outputs = fn(*fn_inputs)
63+
if not isinstance(fn_outputs, (tuple, list)):
64+
fn_outputs = [fn_outputs]
65+
next_states = [out for out in fn_outputs if not isinstance(out, until)]
5166

5267
if len(next_states) > len(init_states):
5368
if not init_states:
@@ -61,27 +76,43 @@ def scan(
6176
prev_states = []
6277
for i, (init_state, next_state) in enumerate(zip(init_states, next_states)):
6378
if init_state is None:
79+
# next_state may reference idx, let's replace that by the initial value
80+
[next_state] = clone_replace(
81+
output=[next_state], replace={symbolic_idx: idx}
82+
)
6483
init_state = empty_like(next_state)
6584
init_state.name = "empty_init_state"
66-
inner_inputs.insert(i, init_state.type())
6785
prev_states.append(init_state)
6886

69-
until_condition = [out.condition for out in inner_outputs if isinstance(out, until)]
87+
until_condition = [out.condition for out in fn_outputs if isinstance(out, until)]
7088
if not until_condition:
7189
until_condition = [as_tensor(np.array(True))]
7290
if len(until_condition) > 1:
7391
raise ValueError("Only one until condition can be returned")
7492

75-
update_fg = FunctionGraph(
76-
inputs=inner_inputs, outputs=until_condition + next_states
93+
fgraph_inputs = [symbolic_idx] + prev_states + sequences + non_sequences
94+
fgraph_outputs = until_condition + [symbolic_idx + 1] + next_states
95+
96+
all_fgraph_inputs = truncated_graph_inputs(
97+
fgraph_outputs, ancestors_to_include=fgraph_inputs
98+
)
99+
extra_fgraph_inputs = [
100+
inp
101+
for inp in all_fgraph_inputs
102+
if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
103+
]
104+
fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
105+
update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)
106+
107+
scan_op = Scan(update_fg=update_fg)
108+
scan_outs = scan_op(
109+
n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
77110
)
78-
scan_op = Scan(update_fg=update_fg, n_sequences=len(sequences))
79-
scan_outs = scan_op(n_steps, *prev_states, *sequences, *non_sequences)
80111
assert isinstance(scan_outs, list)
81112
last_states = scan_outs[: scan_op.n_states]
82113
traces = scan_outs[scan_op.n_states :]
83-
84-
return last_states, traces
114+
# Don't return the inner index state
115+
return last_states[1:], traces[1:]
85116

86117

87118
def map(

0 commit comments

Comments
 (0)