Skip to content

Commit c6a0e5b

Browse files
committed
Fix Scan JAX dispatcher
1 parent 2dc912d commit c6a0e5b

File tree

2 files changed

+107
-41
lines changed

2 files changed

+107
-41
lines changed

pytensor/link/jax/dispatch/scan.py

+51-27
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
import jax
22
import jax.numpy as jnp
33

4-
from pytensor.graph.fg import FunctionGraph
54
from pytensor.link.jax.dispatch.basic import jax_funcify
65
from pytensor.scan.op import Scan
76
from pytensor.scan.utils import ScanArgs
87

98

109
@jax_funcify.register(Scan)
1110
def jax_funcify_Scan(op, **kwargs):
12-
inner_fg = FunctionGraph(op.inputs, op.outputs)
13-
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
11+
# TODO: Raise NotImplementedError if While scan
12+
13+
# Apply inner rewrites
14+
# TODO: Not sure this is the right place to do this, should we have a rewrite that
15+
# explicitly triggers the optimization of the inner graphs of Scan?
16+
# The C-code defers it to the make_thunk phase
17+
fgraph = op.fgraph.clone()
18+
rewriter = op.mode_instance.optimizer
19+
rewriter(fgraph)
20+
scan_inner_func = jax_funcify(fgraph, **kwargs)
1421

1522
def scan(*outer_inputs):
1623
scan_args = ScanArgs(
17-
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
24+
list(outer_inputs),
25+
[None] * len(op.inner_outputs),
26+
op.inner_inputs,
27+
op.inner_outputs,
28+
op.info,
1829
)
1930

2031
# `outer_inputs` is a list with the following composite form:
@@ -29,16 +40,13 @@ def scan(*outer_inputs):
2940
n_steps = scan_args.n_steps
3041
seqs = scan_args.outer_in_seqs
3142

32-
# TODO: mit_mots
3343
mit_mot_in_slices = []
44+
if scan_args.outer_in_mit_mot:
45+
raise NotImplementedError("JAX Scan with MIT-MOT not supported yet.")
3446

3547
mit_sot_in_slices = []
3648
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
37-
neg_taps = [abs(t) for t in tap if t < 0]
38-
pos_taps = [abs(t) for t in tap if t > 0]
39-
max_neg = max(neg_taps) if neg_taps else 0
40-
max_pos = max(pos_taps) if pos_taps else 0
41-
init_slice = seq[: max_neg + max_pos]
49+
init_slice = seq[: abs(min(tap))]
4250
mit_sot_in_slices.append(init_slice)
4351

4452
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
@@ -76,6 +84,7 @@ def jax_args_to_inner_scan(op, carry, x):
7684
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
7785
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
7886

87+
# Concatenate lists
7988
inner_scan_inputs = sum(
8089
[
8190
inner_in_seqs,
@@ -116,9 +125,13 @@ def update_mit_sot(mit_sot, new_val):
116125
if not inner_in_sit_sot:
117126
inner_out_sit_sot = []
118127
else:
119-
inner_out_sit_sot = inner_scan_outs
128+
inner_out_sit_sot = inner_scan_outs[
129+
len(inner_in_mit_sot) : len(inner_in_mit_sot)
130+
+ len(inner_in_sit_sot)
131+
]
132+
120133
new_carry = (
121-
inner_in_mit_mot,
134+
inner_in_mit_mot, # Just an empty list, we raise earlier if there are any MIT-MOT
122135
inner_out_mit_sot,
123136
inner_out_sit_sot,
124137
inner_in_shared,
@@ -129,28 +142,39 @@ def update_mit_sot(mit_sot, new_val):
129142

130143
def jax_inner_func(carry, x):
131144
inner_args = jax_args_to_inner_scan(op, carry, x)
132-
inner_scan_outs = list(jax_at_inner_func(*inner_args))
145+
inner_scan_outs = list(scan_inner_func(*inner_args))
133146
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
134147
return new_carry, inner_scan_outs
135148

136-
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
149+
_, scan_outs = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
137150

138151
# We need to prepend the initial values so that the JAX output will
139152
# match the raw `Scan` `Op` output and, thus, work with a downstream
140153
# `Subtensor` `Op` introduced by the `scan` helper function.
141-
def append_scan_out(scan_in_part, scan_out_part):
142-
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
143-
144-
if scan_args.outer_in_mit_sot:
145-
scan_out_final = [
146-
append_scan_out(init, out)
147-
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
148-
]
149-
elif scan_args.outer_in_sit_sot:
150-
scan_out_final = [
151-
append_scan_out(init, out)
152-
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
153-
]
154+
scan_out_final = []
155+
for init, scan_out, buffer in zip(
156+
mit_sot_in_slices
157+
+ sit_sot_in_slices
158+
+ [None] * len(scan_args.outer_in_nit_sot),
159+
scan_outs,
160+
scan_args.outer_in_mit_sot
161+
+ scan_args.outer_in_sit_sot
162+
+ scan_args.outer_in_nit_sot,
163+
):
164+
if init is not None:
165+
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
166+
full_scan_out = jnp.concatenate(
167+
[
168+
jnp.atleast_1d(init),
169+
jnp.atleast_1d(scan_out),
170+
],
171+
axis=0,
172+
)
173+
partial_scan_out = full_scan_out[-buffer.shape[0] :]
174+
else:
175+
# NIT-SOT: Buffer is just the number of entries that should be returned
176+
partial_scan_out = jnp.atleast_1d(scan_out)[-buffer:]
177+
scan_out_final.append(partial_scan_out)
154178

155179
if len(scan_out_final) == 1:
156180
scan_out_final = scan_out_final[0]

tests/link/jax/test_scan.py

+56-14
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,70 @@
11
import numpy as np
22
import pytest
3-
from packaging.version import parse as version_parse
43

54
import pytensor.tensor as at
65
from pytensor.configdefaults import config
76
from pytensor.graph.fg import FunctionGraph
87
from pytensor.scan.basic import scan
98
from pytensor.tensor.math import gammaln, log
10-
from pytensor.tensor.type import ivector, lscalar, scalar
9+
from pytensor.tensor.type import lscalar, scalar, vector
1110
from tests.link.jax.test_basic import compare_jax_and_py
1211

1312

1413
jax = pytest.importorskip("jax")
1514

1615

17-
@pytest.mark.xfail(
18-
version_parse(jax.__version__) >= version_parse("0.2.12"),
19-
reason="Omnistaging cannot be disabled",
20-
)
21-
def test_jax_scan_multiple_output():
16+
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
def test_simple_sit_sot_scan(view):
    """Check a single-tap recurrence ``x[t] = x[t-1] + 1`` built with ``scan``.

    The scan runs for 10 steps starting from the scalar ``x0``. When ``view``
    is given, the scan output is indexed with it before the graph is passed
    to ``compare_jax_and_py``.
    """
    x0 = at.scalar("x0", dtype="float64")
    xs, _ = scan(
        lambda xtm1: xtm1 + 1,
        outputs_info=[x0],
        n_steps=10,
    )
    # Compare against None explicitly: a falsy but valid index (e.g. ``0`` or
    # ``()``) must still be applied rather than silently skipped.
    if view is not None:
        xs = xs[view]
    fg = FunctionGraph([x0], [xs])
    test_input_vals = [np.e]
    compare_jax_and_py(fg, test_input_vals)
29+
30+
31+
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
def test_simple_mit_sot_scan(view):
    """Check a two-tap recurrence ``x[t] = x[t-3] + x[t-1] + 1`` built with ``scan``.

    The initial state ``x0`` is a length-3 vector used with taps ``[-3, -1]``
    and the scan runs for 10 steps. When ``view`` is given, the scan output is
    indexed with it before the graph is passed to ``compare_jax_and_py``.
    """
    x0 = at.vector("x0", dtype="float64", shape=(3,))
    xs, _ = scan(
        lambda xtm3, xtm1: xtm3 + xtm1 + 1,
        outputs_info=[{"initial": x0, "taps": [-3, -1]}],
        n_steps=10,
    )
    # Compare against None explicitly: a falsy but valid index (e.g. ``0`` or
    # ``()``) must still be applied rather than silently skipped.
    if view is not None:
        xs = xs[view]
    fg = FunctionGraph([x0], [xs])
    test_input_vals = [np.full((3,), np.e)]
    compare_jax_and_py(fg, test_input_vals)
44+
45+
46+
@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
def test_simple_nit_sot_scan(view):
    """Check a scan output that has no initial state (``outputs_info`` entry ``None``).

    The scan iterates over the length-10 sequence ``xs`` and carries a scalar
    state ``s`` that is incremented each step; the tested output is
    ``exp(x + s)``. When ``view`` is given, that output is indexed with it
    before the graph is passed to ``compare_jax_and_py``.
    """
    rng = np.random.default_rng(seed=49)

    xs = at.vector("x0", dtype="float64", shape=(10,))
    s0 = at.zeros(())

    # We need to have a recurring state, otherwise this simple scan would
    # be rewritten as a simple Elemwise on xs
    [_, ys], _ = scan(
        lambda x, s: (s + 1, at.exp(x + s)),
        outputs_info=[s0, None],
        sequences=[xs],
    )
    # Compare against None explicitly: a falsy but valid index (e.g. ``0`` or
    # ``()``) must still be applied rather than silently skipped.
    if view is not None:
        ys = ys[view]
    fg = FunctionGraph([xs], [ys])
    test_input_vals = [rng.normal(size=10)]
    compare_jax_and_py(fg, test_input_vals)
65+
66+
67+
def test_SEIR_scan():
2268
"""Test a scan implementation of a SEIR model.
2369
2470
SEIR model definition:
@@ -38,8 +84,8 @@ def binom_log_prob(n, p, value):
3884
return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
3985

4086
# sequences
41-
at_C = ivector("C_t")
42-
at_D = ivector("D_t")
87+
at_C = vector("C_t", dtype="int32", shape=(8,))
88+
at_D = vector("D_t", dtype="int32", shape=(8,))
4389
# outputs_info (initial conditions)
4490
st0 = lscalar("s_t0")
4591
et0 = lscalar("e_t0")
@@ -108,11 +154,7 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
108154
compare_jax_and_py(out_fg, test_input_vals)
109155

110156

111-
@pytest.mark.xfail(
112-
version_parse(jax.__version__) >= version_parse("0.2.12"),
113-
reason="Omnistaging cannot be disabled",
114-
)
115-
def test_jax_scan_tap_output():
157+
def test_mitsot_with_nonseq_scan():
116158
a_at = scalar("a")
117159

118160
def input_step_fn(y_tm1, y_tm3, a):

0 commit comments

Comments
 (0)