
Commit 6c953b3

ricardoV94 and aseyboldt committed
Implement new scan constructor user facing functions
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent b16bd74 commit 6c953b3

File tree

2 files changed: +303 -0 lines changed


pytensor/loop/basic.py

+199
@@ -0,0 +1,199 @@
import functools
from typing import List, Tuple

import numpy as np

from pytensor import Variable, as_symbolic, clone_replace
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, truncated_graph_inputs
from pytensor.loop.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import as_tensor, constant, empty_like, minimum


def scan(
    fn,
    init_states=None,
    sequences=None,
    non_sequences=None,
    n_steps=None,
    go_backwards=False,
) -> Tuple[List[Variable], List[Variable]]:
    if sequences is None and n_steps is None:
        raise ValueError("Must provide n_steps when scanning without sequences")

    if init_states is None:
        init_states = []
    else:
        if not isinstance(init_states, (tuple, list)):
            init_states = [init_states]
        init_states = [as_symbolic(i) if i is not None else None for i in init_states]

    if sequences is None:
        sequences = []
    else:
        if not isinstance(sequences, (tuple, list)):
            sequences = [sequences]
        sequences = [as_tensor(s) for s in sequences]

    if sequences:
        leading_dims = [seq.shape[0] for seq in sequences]
        shortest_dim = functools.reduce(minimum, leading_dims)
        if n_steps is None:
            n_steps = shortest_dim
        else:
            n_steps = minimum(n_steps, shortest_dim)

    if non_sequences is None:
        non_sequences = []
    else:
        if not isinstance(non_sequences, (tuple, list)):
            non_sequences = [non_sequences]
        non_sequences = [as_symbolic(n) for n in non_sequences]

    # Create dummy inputs for the init state. The user function should not
    # draw any relationship with the outer initial states, since these are only
    # valid in the first iteration
    inner_states = [i.type() if i is not None else None for i in init_states]

    # Create subsequence inputs for the inner function
    idx = constant(0, dtype="int64", name="idx")
    symbolic_idx = idx.type(name="idx")
    subsequences = [s[symbolic_idx] for s in sequences]

    # Call user function to retrieve inner outputs. We use the same order as the old Scan,
    # although inner_states + subsequences + non_sequences seems more intuitive,
    # since subsequences are just a fancy non_sequence
    # We don't pass the non-carried outputs [init is None] to the inner function
    fn_inputs = (
        subsequences + [i for i in inner_states if i is not None] + non_sequences
    )
    fn_outputs = fn(*fn_inputs)
    if not isinstance(fn_outputs, (tuple, list)):
        fn_outputs = [fn_outputs]
    next_states = [out for out in fn_outputs if not isinstance(out, until)]

    if len(next_states) > len(init_states):
        if not init_states:
            init_states = [None] * len(next_states)
            inner_states = init_states
        else:
            raise ValueError(
                "Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map)"
            )

    # Replace None init by dummy empty tensors
    prev_states = []
    prev_inner_states = []
    for i, (init_state, inner_state, next_state) in enumerate(
        zip(init_states, inner_states, next_states)
    ):
        if init_state is None:
            # next_state may reference idx. We replace that by the initial value,
            # so that the shape of the dummy init state does not depend on it.
            [next_state] = clone_replace(
                output=[next_state], replace={symbolic_idx: idx}
            )
            init_state = empty_like(next_state)
            init_state.name = "empty_init_state"
            inner_state = init_state.type(name="dummy_state")
        prev_states.append(init_state)
        prev_inner_states.append(inner_state)

    until_condition = [out.condition for out in fn_outputs if isinstance(out, until)]
    if not until_condition:
        until_condition = [as_tensor(np.array(True))]
    if len(until_condition) > 1:
        raise ValueError("Only one until condition can be returned")

    fgraph_inputs = [symbolic_idx] + prev_inner_states + sequences + non_sequences
    fgraph_outputs = until_condition + [symbolic_idx + 1] + next_states

    all_fgraph_inputs = truncated_graph_inputs(
        fgraph_outputs, ancestors_to_include=fgraph_inputs
    )
    extra_fgraph_inputs = [
        inp
        for inp in all_fgraph_inputs
        if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
    ]
    fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
    update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)

    scan_op = Scan(update_fg=update_fg)
    scan_outs = scan_op(
        n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
    )
    assert isinstance(scan_outs, list)
    last_states = scan_outs[: scan_op.n_states]
    traces = scan_outs[scan_op.n_states :]
    # Don't return the inner index state
    return last_states[1:], traces[1:]


def map(
    fn,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    _, traces = scan(
        fn=fn,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )
    if len(traces) == 1:
        return traces[0]
    return traces


def reduce(
    fn,
    init_states,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    final_states, _ = scan(
        fn=fn,
        init_states=init_states,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )
    if len(final_states) == 1:
        return final_states[0]
    return final_states


def filter(
    fn,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    if not isinstance(sequences, (tuple, list)):
        sequences = [sequences]

    _, masks = scan(
        fn=fn,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )

    if not all(mask.dtype == "bool" for mask in masks):
        raise TypeError("The output of filter fn should be a boolean variable")
    if len(masks) == 1:
        masks = [masks[0]] * len(sequences)
    elif len(masks) != len(sequences):
        raise ValueError(
            f"filter fn must return one variable or len(sequences), but it returned {len(masks)}"
        )

    filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]

    if len(filtered_sequences) == 1:
        return filtered_sequences[0]
    return filtered_sequences
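
The constructor also recognizes an `until` object among the outputs of `fn` and wires its condition into the inner FunctionGraph as the first output (the `until_condition` branch above), which is meant to allow early exit. A minimal sketch of how that could be used, assuming the new Scan op honors the condition at runtime (hypothetical usage, not covered by this commit's tests):

from pytensor.loop.basic import scan
from pytensor.scan.utils import until
from pytensor.tensor import scalar

x = scalar("x")
# fn returns the next carried state plus an until() marker; scan()
# strips the marker from next_states and passes its condition to the
# inner FunctionGraph as the first output.
last_states, traces = scan(
    fn=lambda xtm1: (xtm1 + 1, until(xtm1 >= 5)),
    init_states=[x],
    n_steps=100,
)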

tests/loop/test_basic.py

+104
@@ -0,0 +1,104 @@
import numpy as np

import pytensor
from pytensor import function, grad
from pytensor.loop.basic import filter, map, reduce, scan
from pytensor.tensor import arange, eq, scalar, vector, zeros


def test_scan_with_sequences():
    xs = vector("xs")
    ys = vector("ys")
    _, [zs] = scan(
        fn=lambda x, y: x * y,
        sequences=[xs, ys],
    )
    pytensor.dprint(ys, print_type=True)
    np.testing.assert_almost_equal(
        zs.eval({xs: np.arange(10), ys: np.arange(10)}),
        np.arange(10) ** 2,
    )


def test_scan_with_carried_and_non_carried_states():
    x = scalar("x")
    _, [ys1, ys2] = scan(
        fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
        init_states=[x, None],
        n_steps=10,
    )
    fn = function([x], [ys1, ys2])
    res = fn(-1)
    np.testing.assert_almost_equal(res[0], np.arange(10))
    np.testing.assert_almost_equal(res[1], np.arange(10) * 2)


def test_scan_with_sequence_and_carried_state():
    xs = vector("xs")
    _, [ys] = scan(
        fn=lambda x, ytm1: (ytm1 + 1) * x,
        init_states=[zeros(())],
        sequences=[xs],
    )
    fn = function([xs], ys)
    np.testing.assert_almost_equal(fn([1, 2, 3]), [1, 4, 15])


def test_scan_taking_grads_wrt_non_sequence():
    # Tests sequence + non-carried state
    xs = vector("xs")
    ys = xs**2

    _, [J] = scan(
        lambda i, ys, xs: grad(ys[i], wrt=xs),
        sequences=arange(ys.shape[0]),
        non_sequences=[ys, xs],
    )

    f = pytensor.function([xs], J)
    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


def test_scan_taking_grads_wrt_sequence():
    # This is not possible with the old Scan
    xs = vector("xs")
    ys = xs**2

    _, [J] = scan(
        lambda y, xs: grad(y, wrt=xs),
        sequences=[ys],
        non_sequences=[xs],
    )

    f = pytensor.function([xs], J)
    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


def test_map():
    xs = vector("xs")
    ys = map(
        fn=lambda x: x * 100,
        sequences=xs,
    )
    np.testing.assert_almost_equal(ys.eval({xs: np.arange(10)}), np.arange(10) * 100)


def test_reduce():
    xs = vector("xs")
    y = reduce(
        fn=lambda x, acc: acc + x,
        init_states=zeros(()),
        sequences=xs,
    )
    np.testing.assert_almost_equal(
        y.eval({xs: np.arange(10)}), np.arange(10).cumsum()[-1]
    )


def test_filter():
    xs = vector("xs")
    ys = filter(
        fn=lambda x: eq(x % 2, 0),
        sequences=xs,
    )
    np.testing.assert_array_equal(ys.eval({xs: np.arange(0, 20)}), np.arange(0, 20, 2))
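
As a follow-up illustration (hypothetical, not part of this commit), the constructors also compose across multiple sequences; for instance, a dot product can be expressed as a `reduce` over two sequences with a single carried accumulator:

from pytensor.loop.basic import reduce
from pytensor.tensor import vector, zeros

xs = vector("xs")
ys = vector("ys")
# The inner function receives the subsequence elements first and the
# carried state last, mirroring the argument order used by scan().
dot = reduce(
    fn=lambda x, y, acc: acc + x * y,
    init_states=zeros(()),
    sequences=[xs, ys],
)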
