Skip to content

Commit b6ea0b1

Browse files
committed
Add JAX rewrite for new Scan Op
1 parent ebb9d1c commit b6ea0b1

File tree

5 files changed

+164
-7
lines changed

5 files changed

+164
-7
lines changed

pytensor/compile/mode.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
449449

450450
JAX = Mode(
451451
JAXLinker(),
452-
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
452+
RewriteDatabaseQuery(
453+
include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt", "not_jax"]
454+
),
453455
)
454456
NUMBA = Mode(
455457
NumbaLinker(),

pytensor/link/jax/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
import pytensor.link.jax.dispatch.random
1313
import pytensor.link.jax.dispatch.elemwise
1414
import pytensor.link.jax.dispatch.scan
15+
import pytensor.link.jax.dispatch.loop
1516

1617
# isort: on

pytensor/link/jax/dispatch/loop.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import jax
2+
3+
from pytensor.compile.mode import get_mode
4+
from pytensor.link.jax.dispatch.basic import jax_funcify
5+
from pytensor.loop.op import Scan
6+
from pytensor.tensor.type_other import NoneTypeT
7+
8+
9+
@jax_funcify.register(Scan)
10+
def jax_funcify_Scan(op, node, **kwargs):
11+
# TODO: Rewrite as a while loop if only last states are used
12+
if op.has_while_condition:
13+
raise NotImplementedError(
14+
"Scan ops with while condition cannot be transpiled JAX"
15+
)
16+
17+
# Apply inner rewrites
18+
# TODO: Not sure this is the right place to do this, should we have a rewrite that
19+
# explicitly triggers the optimization of the inner graphs of Scan?
20+
update_fg = op.update_fg.clone()
21+
rewriter = get_mode("JAX").optimizer
22+
rewriter(update_fg)
23+
24+
jaxified_scan_inner_fn = jax_funcify(update_fg, **kwargs)
25+
traceable_states = [
26+
i
27+
for i, n in enumerate(node.outputs[op.n_states :])
28+
if not isinstance(n.type, NoneTypeT)
29+
]
30+
31+
def scan(max_iters, *outer_inputs):
32+
states = outer_inputs[: op.n_states]
33+
constants = outer_inputs[op.n_states :]
34+
35+
def scan_fn(carry, _):
36+
resume, *carry = jaxified_scan_inner_fn(*carry, *constants)
37+
assert resume
38+
carry = list(carry)
39+
# Return states as both carry and output to be appended
40+
return carry, [c for i, c in enumerate(carry) if i in traceable_states]
41+
42+
print(max_iters)
43+
states, traces = jax.lax.scan(
44+
scan_fn, init=list(states), xs=None, length=max_iters
45+
)
46+
for i in range(len(states)):
47+
if i not in traceable_states:
48+
traces.insert(i, None)
49+
50+
return *states, *traces
51+
52+
return scan

tests/link/jax/test_basic.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
import pytest
66

77
from pytensor.compile.function import function
8-
from pytensor.compile.mode import Mode
8+
from pytensor.compile.mode import get_mode
99
from pytensor.compile.sharedvalue import SharedVariable, shared
1010
from pytensor.configdefaults import config
1111
from pytensor.graph.basic import Apply
1212
from pytensor.graph.fg import FunctionGraph
1313
from pytensor.graph.op import Op, get_test_value
14-
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1514
from pytensor.ifelse import ifelse
16-
from pytensor.link.jax import JAXLinker
1715
from pytensor.raise_op import assert_op
1816
from pytensor.tensor.type import dscalar, scalar, vector
1917

@@ -27,9 +25,9 @@ def set_pytensor_flags():
2725
jax = pytest.importorskip("jax")
2826

2927

30-
opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
31-
jax_mode = Mode(JAXLinker(), opts)
32-
py_mode = Mode("py", opts)
28+
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
29+
jax_mode = get_mode("JAX")
30+
py_mode = get_mode("FAST_COMPILE")
3331

3432

3533
def compare_jax_and_py(

tests/link/jax/test_loop.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import numpy as np
2+
import pytest
3+
from link.jax.test_basic import compare_jax_and_py
4+
5+
from pytensor import function, shared
6+
from pytensor.graph import FunctionGraph
7+
from pytensor.loop.basic import scan
8+
from pytensor.scan import until
9+
from pytensor.tensor import scalar, vector, zeros
10+
from pytensor.tensor.random import normal
11+
12+
13+
def test_scan_with_single_sequence():
14+
xs = vector("xs")
15+
_, [ys] = scan(lambda x: x * 100, sequences=[xs])
16+
17+
out_fg = FunctionGraph([xs], [ys])
18+
compare_jax_and_py(out_fg, [np.arange(10)])
19+
20+
21+
def test_scan_with_single_sequence_shortened_by_nsteps():
22+
xs = vector("xs", shape=(10,)) # JAX needs the length to be constant
23+
_, [ys] = scan(
24+
lambda x: x * 100,
25+
sequences=[xs],
26+
n_steps=9,
27+
)
28+
29+
out_fg = FunctionGraph([xs], [ys])
30+
compare_jax_and_py(out_fg, [np.arange(10)])
31+
32+
33+
def test_scan_with_multiple_sequences():
34+
# JAX can only handle constant n_steps
35+
xs = vector("xs", shape=(10,))
36+
ys = vector("ys", shape=(10,))
37+
_, [zs] = scan(
38+
fn=lambda x, y: x * y,
39+
sequences=[xs, ys],
40+
)
41+
42+
out_fg = FunctionGraph([xs, ys], [zs])
43+
compare_jax_and_py(
44+
out_fg, [np.arange(10, dtype=xs.dtype), np.arange(10, dtype=ys.dtype)]
45+
)
46+
47+
48+
def test_scan_with_carried_and_non_carried_states():
49+
x = scalar("x")
50+
_, [ys1, ys2] = scan(
51+
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
52+
init_states=[x, None],
53+
n_steps=10,
54+
)
55+
out_fg = FunctionGraph([x], [ys1, ys2])
56+
compare_jax_and_py(out_fg, [-1])
57+
58+
59+
def test_scan_with_sequence_and_carried_state():
60+
xs = vector("xs")
61+
_, [ys] = scan(
62+
fn=lambda x, ytm1: (ytm1 + 1) * x,
63+
init_states=[zeros(())],
64+
sequences=[xs],
65+
)
66+
out_fg = FunctionGraph([xs], [ys])
67+
compare_jax_and_py(out_fg, [[1, 2, 3]])
68+
69+
70+
def test_scan_with_rvs():
71+
rng = shared(np.random.default_rng(123))
72+
73+
[next_rng, _], [_, xs] = scan(
74+
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
75+
init_states=[rng, None],
76+
n_steps=10,
77+
)
78+
79+
# First without updates
80+
fn = function([], xs, mode="JAX", updates=None)
81+
res1 = fn()
82+
res2 = fn()
83+
assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2)))
84+
85+
# Now with updates
86+
fn = function([], xs, mode="JAX", updates={rng: next_rng})
87+
res1 = fn()
88+
res2 = fn()
89+
assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2)))
90+
91+
92+
def test_while_scan_fails():
93+
_, [xs] = scan(
94+
fn=lambda x: (x + 1, until((x + 1) >= 9)),
95+
init_states=[-1],
96+
n_steps=20,
97+
)
98+
99+
out_fg = FunctionGraph([], [xs])
100+
with pytest.raises(
101+
NotImplementedError,
102+
match="Scan ops with while condition cannot be transpiled JAX",
103+
):
104+
compare_jax_and_py(out_fg, [])

0 commit comments

Comments
 (0)