Commit 2c12d63

ricardoV94 and aseyboldt committed
Implement new scan constructor user facing functions
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 2928b3f · commit 2c12d63
File tree

2 files changed: +316, -0 lines changed

pytensor/loop/basic.py (+200 lines)

@@ -0,0 +1,200 @@
import functools
from typing import List, Tuple

import numpy as np

from pytensor import Variable, as_symbolic, clone_replace
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, truncated_graph_inputs
from pytensor.loop.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import as_tensor, constant, empty_like, minimum


def scan(
    fn,
    init_states=None,
    sequences=None,
    non_sequences=None,
    n_steps=None,
    go_backwards=False,
) -> Tuple[List[Variable], List[Variable]]:
    if sequences is None and n_steps is None:
        raise ValueError("Must provide n_steps when scanning without sequences")

    if init_states is None:
        init_states = []
    else:
        if not isinstance(init_states, (tuple, list)):
            init_states = [init_states]
        init_states = [as_symbolic(i) if i is not None else None for i in init_states]

    if sequences is None:
        sequences = []
    else:
        if not isinstance(sequences, (tuple, list)):
            sequences = [sequences]
        sequences = [as_tensor(s) for s in sequences]

    if sequences:
        leading_dims = [seq.shape[0] for seq in sequences]
        shortest_dim = functools.reduce(minimum, leading_dims)
        if n_steps is None:
            n_steps = shortest_dim
        else:
            n_steps = minimum(n_steps, shortest_dim)

    if non_sequences is None:
        non_sequences = []
    else:
        if not isinstance(non_sequences, (tuple, list)):
            non_sequences = [non_sequences]
        non_sequences = [as_symbolic(n) for n in non_sequences]

    # Create dummy inputs for the init state. The user function should not
    # draw any relationship with the outer initial states, since these are only
    # valid in the first iteration
    inner_states = [i.type() if i is not None else None for i in init_states]

    # Create subsequence inputs for the inner function
    idx = constant(0, dtype="int64", name="idx")
    symbolic_idx = idx.type(name="idx")
    subsequences = [s[symbolic_idx] for s in sequences]

    # Call user function to retrieve inner outputs. We use the same order as the old Scan,
    # although inner_states + subsequences + non_sequences seems more intuitive,
    # since subsequences are just a fancy non_sequence
    # We don't pass the non-carried outputs [init is None] to the inner function
    fn_inputs = (
        subsequences + [i for i in inner_states if i is not None] + non_sequences
    )
    fn_outputs = fn(*fn_inputs)
    if not isinstance(fn_outputs, (tuple, list)):
        fn_outputs = [fn_outputs]
    next_states = [out for out in fn_outputs if not isinstance(out, until)]

    if len(next_states) > len(init_states):
        if not init_states:
            init_states = [None] * len(next_states)
            inner_states = init_states
        else:
            raise ValueError(
                "Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map)"
            )

    # Replace None init by dummy empty tensors
    prev_states = []
    prev_inner_states = []
    for i, (init_state, inner_state, next_state) in enumerate(
        zip(init_states, inner_states, next_states)
    ):
        if init_state is None:
            # next_state may reference idx. We replace that by the initial value,
            # so that the shape of the dummy init state does not depend on it.
            [next_state] = clone_replace(
                output=[next_state], replace={symbolic_idx: idx}
            )
            init_state = empty_like(next_state)
            init_state.name = "empty_init_state"
            inner_state = init_state.type(name="dummy_state")
        prev_states.append(init_state)
        prev_inner_states.append(inner_state)

    # Flip until to while condition
    while_condition = [~out.condition for out in fn_outputs if isinstance(out, until)]
    if not while_condition:
        while_condition = [as_tensor(np.array(True))]
    if len(while_condition) > 1:
        raise ValueError("Only one until condition can be returned")

    fgraph_inputs = [symbolic_idx] + prev_inner_states + sequences + non_sequences
    fgraph_outputs = while_condition + [symbolic_idx + 1] + next_states

    all_fgraph_inputs = truncated_graph_inputs(
        fgraph_outputs, ancestors_to_include=fgraph_inputs
    )
    extra_fgraph_inputs = [
        inp
        for inp in all_fgraph_inputs
        if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
    ]
    fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
    update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)

    scan_op = Scan(update_fg=update_fg)
    scan_outs = scan_op(
        n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
    )
    assert isinstance(scan_outs, list)
    last_states = scan_outs[: scan_op.n_states]
    traces = scan_outs[scan_op.n_states :]
    # Don't return the inner index state
    return last_states[1:], traces[1:]


def map(
    fn,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    _, traces = scan(
        fn=fn,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )
    if len(traces) == 1:
        return traces[0]
    return traces


def reduce(
    fn,
    init_states,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    final_states, _ = scan(
        fn=fn,
        init_states=init_states,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )
    if len(final_states) == 1:
        return final_states[0]
    return final_states


def filter(
    fn,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    if not isinstance(sequences, (tuple, list)):
        sequences = [sequences]

    _, masks = scan(
        fn=fn,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )

    if not all(mask.dtype == "bool" for mask in masks):
        raise TypeError("The output of filter fn should be a boolean variable")
    if len(masks) == 1:
        masks = [masks[0]] * len(sequences)
    elif len(masks) != len(sequences):
        raise ValueError(
            f"filter fn must return one variable or len(sequences), but it returned {len(masks)}"
        )

    filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]

    if len(filtered_sequences) == 1:
        return filtered_sequences[0]
    return filtered_sequences
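A minimal usage sketch (not part of this commit's diff) of the scan constructor defined above, mirroring the carried-state pattern exercised by the tests below; the names total and partials are illustrative only:

import numpy as np

from pytensor import function
from pytensor.loop.basic import scan
from pytensor.tensor import vector, zeros

xs = vector("xs")

# Carry a running sum over xs: the final state is the total,
# the trace holds every intermediate partial sum.
[total], [partials] = scan(
    fn=lambda x, acc: acc + x,
    init_states=[zeros(())],
    sequences=[xs],
)

f = function([xs], [total, partials])
# Based on the reduce and sequence-plus-carried-state tests in this commit,
# f(np.arange(5.0)) should give total == 10.0 and partials == [0., 1., 3., 6., 10.]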

tests/loop/test_basic.py (+116 lines)

@@ -0,0 +1,116 @@
import numpy as np

import pytensor
from pytensor import function, grad
from pytensor.loop.basic import filter, map, reduce, scan
from pytensor.scan import until
from pytensor.tensor import arange, eq, scalar, vector, zeros


def test_scan_with_sequences():
    xs = vector("xs")
    ys = vector("ys")
    _, [zs] = scan(
        fn=lambda x, y: x * y,
        sequences=[xs, ys],
    )
    pytensor.dprint(ys, print_type=True)
    np.testing.assert_almost_equal(
        zs.eval({xs: np.arange(10), ys: np.arange(10)}),
        np.arange(10) ** 2,
    )


def test_scan_with_carried_and_non_carried_states():
    x = scalar("x")
    _, [ys1, ys2] = scan(
        fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
        init_states=[x, None],
        n_steps=10,
    )
    fn = function([x], [ys1, ys2])
    res = fn(-1)
    np.testing.assert_almost_equal(res[0], np.arange(10))
    np.testing.assert_almost_equal(res[1], np.arange(10) * 2)


def test_scan_with_sequence_and_carried_state():
    xs = vector("xs")
    _, [ys] = scan(
        fn=lambda x, ytm1: (ytm1 + 1) * x,
        init_states=[zeros(())],
        sequences=[xs],
    )
    fn = function([xs], ys)
    np.testing.assert_almost_equal(fn([1, 2, 3]), [1, 4, 15])


def test_scan_taking_grads_wrt_non_sequence():
    # Tests sequence + non-carried state
    xs = vector("xs")
    ys = xs**2

    _, [J] = scan(
        lambda i, ys, xs: grad(ys[i], wrt=xs),
        sequences=arange(ys.shape[0]),
        non_sequences=[ys, xs],
    )

    f = pytensor.function([xs], J)
    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


def test_scan_taking_grads_wrt_sequence():
    # This is not possible with the old Scan
    xs = vector("xs")
    ys = xs**2

    _, [J] = scan(
        lambda y, xs: grad(y, wrt=xs),
        sequences=[ys],
        non_sequences=[xs],
    )

    f = pytensor.function([xs], J)
    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


def test_while_scan():
    _, [xs] = scan(
        fn=lambda x: (x + 1, until((x + 1) >= 9)),
        init_states=[-1],
        n_steps=20,
    )

    f = pytensor.function([], xs)
    np.testing.assert_array_equal(f(), np.arange(10))


def test_map():
    xs = vector("xs")
    ys = map(
        fn=lambda x: x * 100,
        sequences=xs,
    )
    np.testing.assert_almost_equal(ys.eval({xs: np.arange(10)}), np.arange(10) * 100)


def test_reduce():
    xs = vector("xs")
    y = reduce(
        fn=lambda x, acc: acc + x,
        init_states=zeros(()),
        sequences=xs,
    )
    np.testing.assert_almost_equal(
        y.eval({xs: np.arange(10)}), np.arange(10).cumsum()[-1]
    )


def test_filter():
    xs = vector("xs")
    ys = filter(
        fn=lambda x: eq(x % 2, 0),
        sequences=xs,
    )
    np.testing.assert_array_equal(ys.eval({xs: np.arange(0, 20)}), np.arange(0, 20, 2))
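A further sketch (not part of the commit, and untested against it) of filter over two sequences sharing a single mask, the len(masks) == 1 broadcast branch of filter in pytensor/loop/basic.py above that the tests do not cover:

from pytensor.loop.basic import filter
from pytensor.tensor import eq, vector

xs = vector("xs")
ys = vector("ys")

# fn receives one element per sequence but returns a single boolean mask,
# which filter applies to both sequences.
xs_even, ys_even = filter(
    fn=lambda x, y: eq(x % 2, 0),
    sequences=[xs, ys],
)
# Expected, by analogy with test_filter above: xs_even keeps the even entries
# of xs and ys_even keeps the entries of ys at the same positions.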
