Skip to content

Commit 76a9b4c

Browse files
ricardoV94aseyboldt
andcommitted
Implement new Loop and Scan operators
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent b8831aa commit 76a9b4c

File tree

6 files changed

+777
-0
lines changed

6 files changed

+777
-0
lines changed

pytensor/loop/__init__.py

Whitespace-only changes.

pytensor/loop/basic.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from typing import List, Tuple
2+
3+
import numpy as np
4+
5+
from pytensor import Variable, as_symbolic
6+
from pytensor.graph import FunctionGraph
7+
from pytensor.loop.op import Scan
8+
from pytensor.scan.utils import until
9+
from pytensor.tensor import as_tensor, empty_like
10+
11+
12+
def scan(
13+
fn,
14+
init_states=None,
15+
sequences=None,
16+
non_sequences=None,
17+
n_steps=None,
18+
go_backwards=False,
19+
) -> Tuple[List[Variable], List[Variable]]:
20+
if sequences is None and n_steps is None:
21+
raise ValueError("Must provide n_steps when scanning without sequences")
22+
23+
if init_states is None:
24+
init_states = []
25+
else:
26+
if not isinstance(init_states, (tuple, list)):
27+
init_states = [init_states]
28+
init_states = [as_symbolic(i) for i in init_states]
29+
30+
if sequences is None:
31+
sequences = []
32+
else:
33+
if not isinstance(sequences, (tuple, list)):
34+
sequences = [sequences]
35+
sequences = [as_tensor(s) for s in sequences]
36+
37+
if non_sequences is None:
38+
non_sequences = []
39+
else:
40+
if not isinstance(non_sequences, (tuple, list)):
41+
non_sequences = [non_sequences]
42+
non_sequences = [as_symbolic(n) for n in non_sequences]
43+
44+
# Note: Old scan order is sequences + init + non_sequences
45+
inner_sequences = [s[0] for s in sequences]
46+
inner_inputs = [i.type() for i in init_states + inner_sequences + non_sequences]
47+
inner_outputs = fn(*inner_inputs)
48+
if not isinstance(inner_outputs, (tuple, list)):
49+
inner_outputs = [inner_outputs]
50+
next_states = [out for out in inner_outputs if not isinstance(out, until)]
51+
52+
if len(next_states) > len(init_states):
53+
if not init_states:
54+
init_states = [None] * len(next_states)
55+
else:
56+
raise ValueError(
57+
"Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map) "
58+
)
59+
60+
# Replace None init by dummy empty tensors
61+
prev_states = []
62+
for i, (init_state, next_state) in enumerate(zip(init_states, next_states)):
63+
if init_state is None:
64+
init_state = empty_like(next_state)
65+
init_state.name = "empty_init_state"
66+
inner_inputs.insert(i, init_state.type())
67+
prev_states.append(init_state)
68+
69+
until_condition = [out.condition for out in inner_outputs if isinstance(out, until)]
70+
if not until_condition:
71+
until_condition = [as_tensor(np.array(True))]
72+
if len(until_condition) > 1:
73+
raise ValueError("Only one until condition can be returned")
74+
75+
update_fg = FunctionGraph(
76+
inputs=inner_inputs, outputs=until_condition + next_states
77+
)
78+
scan_op = Scan(update_fg=update_fg, n_sequences=len(sequences))
79+
scan_outs = scan_op(n_steps, *prev_states, *sequences, *non_sequences)
80+
assert isinstance(scan_outs, list)
81+
last_states = scan_outs[: scan_op.n_states]
82+
traces = scan_outs[scan_op.n_states :]
83+
84+
return last_states, traces
85+
86+
87+
def map(
88+
fn,
89+
sequences,
90+
non_sequences=None,
91+
go_backwards=False,
92+
):
93+
_, traces = scan(
94+
fn=fn,
95+
sequences=sequences,
96+
non_sequences=non_sequences,
97+
go_backwards=go_backwards,
98+
)
99+
if len(traces) == 1:
100+
return traces[0]
101+
return traces
102+
103+
104+
def reduce(
105+
fn,
106+
init_states,
107+
sequences,
108+
non_sequences=None,
109+
go_backwards=False,
110+
):
111+
final_states, _ = scan(
112+
fn=fn,
113+
init_states=init_states,
114+
sequences=sequences,
115+
non_sequences=non_sequences,
116+
go_backwards=go_backwards,
117+
)
118+
if len(final_states) == 1:
119+
return final_states[0]
120+
return final_states

0 commit comments

Comments
 (0)