Skip to content

Commit 5f15c5e

Browse files
committed
Make scan helper return sequences to match old API
This was not possible prior to use of TypedListType for non TensorVariable sequences, as it would otherwise not be possible to represent indexing of last sequence state, which is needed e.g., for shared random generator updates.
1 parent a750fd7 commit 5f15c5e

File tree

3 files changed

+53
-32
lines changed

3 files changed

+53
-32
lines changed

pytensor/loop/basic.py

+17-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import functools
2-
from typing import List, Tuple
2+
from typing import List, Union
33

44
import numpy as np
55

@@ -18,7 +18,7 @@ def scan(
1818
non_sequences=None,
1919
n_steps=None,
2020
go_backwards=False,
21-
) -> Tuple[List[Variable], List[Variable]]:
21+
) -> Union[Variable, List[Variable]]:
2222
if sequences is None and n_steps is None:
2323
raise ValueError("Must provide n_steps when scanning without sequences")
2424

@@ -126,10 +126,11 @@ def scan(
126126
n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
127127
)
128128
assert isinstance(scan_outs, list)
129-
last_states = scan_outs[: scan_op.n_states]
130-
traces = scan_outs[scan_op.n_states :]
131-
# Don't return the inner index state
132-
return last_states[1:], traces[1:]
129+
# Don't return the last states or the trace for the inner index
130+
traces = scan_outs[scan_op.n_states + 1 :]
131+
if len(traces) == 1:
132+
return traces[0]
133+
return traces
133134

134135

135136
def map(
@@ -138,14 +139,12 @@ def map(
138139
non_sequences=None,
139140
go_backwards=False,
140141
):
141-
_, traces = scan(
142+
traces = scan(
142143
fn=fn,
143144
sequences=sequences,
144145
non_sequences=non_sequences,
145146
go_backwards=go_backwards,
146147
)
147-
if len(traces) == 1:
148-
return traces[0]
149148
return traces
150149

151150

@@ -156,16 +155,16 @@ def reduce(
156155
non_sequences=None,
157156
go_backwards=False,
158157
):
159-
final_states, _ = scan(
158+
traces = scan(
160159
fn=fn,
161160
init_states=init_states,
162161
sequences=sequences,
163162
non_sequences=non_sequences,
164163
go_backwards=go_backwards,
165164
)
166-
if len(final_states) == 1:
167-
return final_states[0]
168-
return final_states
165+
if not isinstance(traces, list):
166+
return traces[-1]
167+
return [trace[-1] for trace in traces]
169168

170169

171170
def filter(
@@ -177,21 +176,21 @@ def filter(
177176
if not isinstance(sequences, (tuple, list)):
178177
sequences = [sequences]
179178

180-
_, masks = scan(
179+
masks = scan(
181180
fn=fn,
182181
sequences=sequences,
183182
non_sequences=non_sequences,
184183
go_backwards=go_backwards,
185184
)
186185

187-
if not all(mask.dtype == "bool" for mask in masks):
188-
raise TypeError("The output of filter fn should be a boolean variable")
189-
if len(masks) == 1:
190-
masks = [masks[0]] * len(sequences)
186+
if not isinstance(masks, list):
187+
masks = [masks] * len(sequences)
191188
elif len(masks) != len(sequences):
192189
raise ValueError(
193190
"filter fn must return one variable or len(sequences), but it returned {len(masks)}"
194191
)
192+
if not all(mask.dtype == "bool" for mask in masks):
193+
raise TypeError("The output of filter fn should be a boolean variable")
195194

196195
filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]
197196

tests/link/jax/test_loop.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313

1414
def test_scan_with_single_sequence():
1515
xs = vector("xs")
16-
_, [ys] = scan(lambda x: x * 100, sequences=[xs])
16+
ys = scan(lambda x: x * 100, sequences=[xs])
1717

1818
out_fg = FunctionGraph([xs], [ys])
1919
compare_jax_and_py(out_fg, [np.arange(10)])
2020

2121

2222
def test_scan_with_single_sequence_shortened_by_nsteps():
2323
xs = vector("xs", shape=(10,)) # JAX needs the length to be constant
24-
_, [ys] = scan(
24+
ys = scan(
2525
lambda x: x * 100,
2626
sequences=[xs],
2727
n_steps=9,
@@ -35,7 +35,7 @@ def test_scan_with_multiple_sequences():
3535
# JAX can only handle constant n_steps
3636
xs = vector("xs", shape=(10,))
3737
ys = vector("ys", shape=(10,))
38-
_, [zs] = scan(
38+
zs = scan(
3939
fn=lambda x, y: x * y,
4040
sequences=[xs, ys],
4141
)
@@ -48,7 +48,7 @@ def test_scan_with_multiple_sequences():
4848

4949
def test_scan_with_carried_and_non_carried_states():
5050
x = scalar("x")
51-
_, [ys1, ys2] = scan(
51+
[ys1, ys2] = scan(
5252
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
5353
init_states=[x, None],
5454
n_steps=10,
@@ -59,7 +59,7 @@ def test_scan_with_carried_and_non_carried_states():
5959

6060
def test_scan_with_sequence_and_carried_state():
6161
xs = vector("xs")
62-
_, [ys] = scan(
62+
ys = scan(
6363
fn=lambda x, ytm1: (ytm1 + 1) * x,
6464
init_states=[zeros(())],
6565
sequences=[xs],
@@ -71,11 +71,12 @@ def test_scan_with_sequence_and_carried_state():
7171
def test_scan_with_rvs():
7272
rng = shared(np.random.default_rng(123))
7373

74-
[final_rng, _], [rngs, xs] = scan(
74+
[rngs, xs] = scan(
7575
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
7676
init_states=[rng, None],
7777
n_steps=10,
7878
)
79+
final_rng = rngs[-1]
7980

8081
# First without updates
8182
fn = function([], xs, mode="JAX", updates=None)
@@ -99,7 +100,7 @@ def test_scan_with_rvs():
99100

100101

101102
def test_while_scan_fails():
102-
_, [xs] = scan(
103+
xs = scan(
103104
fn=lambda x: (x + 1, until((x + 1) >= 9)),
104105
init_states=[-1],
105106
n_steps=20,

tests/loop/test_basic.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import numpy as np
22

33
import pytensor
4-
from pytensor import function, grad
4+
from pytensor import function, grad, shared
55
from pytensor.loop.basic import filter, map, reduce, scan
66
from pytensor.scan import until
77
from pytensor.tensor import arange, eq, scalar, vector, zeros
8+
from pytensor.tensor.random import normal
89

910

1011
def test_scan_with_sequences():
1112
xs = vector("xs")
1213
ys = vector("ys")
13-
_, [zs] = scan(
14+
zs = scan(
1415
fn=lambda x, y: x * y,
1516
sequences=[xs, ys],
1617
)
@@ -23,7 +24,7 @@ def test_scan_with_sequences():
2324

2425
def test_scan_with_carried_and_non_carried_states():
2526
x = scalar("x")
26-
_, [ys1, ys2] = scan(
27+
[ys1, ys2] = scan(
2728
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
2829
init_states=[x, None],
2930
n_steps=10,
@@ -36,7 +37,7 @@ def test_scan_with_carried_and_non_carried_states():
3637

3738
def test_scan_with_sequence_and_carried_state():
3839
xs = vector("xs")
39-
_, [ys] = scan(
40+
ys = scan(
4041
fn=lambda x, ytm1: (ytm1 + 1) * x,
4142
init_states=[zeros(())],
4243
sequences=[xs],
@@ -50,7 +51,7 @@ def test_scan_taking_grads_wrt_non_sequence():
5051
xs = vector("xs")
5152
ys = xs**2
5253

53-
_, [J] = scan(
54+
J = scan(
5455
lambda i, ys, xs: grad(ys[i], wrt=xs),
5556
sequences=arange(ys.shape[0]),
5657
non_sequences=[ys, xs],
@@ -65,7 +66,7 @@ def test_scan_taking_grads_wrt_sequence():
6566
xs = vector("xs")
6667
ys = xs**2
6768

68-
_, [J] = scan(
69+
J = scan(
6970
lambda y, xs: grad(y, wrt=xs),
7071
sequences=[ys],
7172
non_sequences=[xs],
@@ -76,7 +77,7 @@ def test_scan_taking_grads_wrt_sequence():
7677

7778

7879
def test_while_scan():
79-
_, [xs] = scan(
80+
xs = scan(
8081
fn=lambda x: (x + 1, until((x + 1) >= 9)),
8182
init_states=[-1],
8283
n_steps=20,
@@ -86,6 +87,26 @@ def test_while_scan():
8687
np.testing.assert_array_equal(f(), np.arange(10))
8788

8889

90+
def test_scan_rvs():
91+
rng = shared(np.random.default_rng(123))
92+
test_rng = np.random.default_rng(123)
93+
94+
def normal_fn(prev_rng):
95+
next_rng, x = normal(rng=prev_rng).owner.outputs
96+
return next_rng, x
97+
98+
[rngs, xs] = scan(
99+
fn=normal_fn,
100+
init_states=[rng, None],
101+
n_steps=5,
102+
)
103+
fn = function([], xs, updates={rng: rngs[-1]})
104+
105+
for i in range(3):
106+
res = fn()
107+
np.testing.assert_almost_equal(res, test_rng.normal(size=5))
108+
109+
89110
def test_map():
90111
xs = vector("xs")
91112
ys = map(

0 commit comments

Comments
 (0)