Skip to content

Commit b0aa0b9

Browse files
committed
Fix Scan JAX dispatcher
1 parent eda9c46 commit b0aa0b9

File tree

2 files changed

+327
-130
lines changed

2 files changed

+327
-130
lines changed

pytensor/link/jax/dispatch/scan.py

+166-116
Original file line numberDiff line numberDiff line change
@@ -1,159 +1,209 @@
11
import jax
22
import jax.numpy as jnp
33

4-
from pytensor.graph.fg import FunctionGraph
54
from pytensor.link.jax.dispatch.basic import jax_funcify
65
from pytensor.scan.op import Scan
7-
from pytensor.scan.utils import ScanArgs
86

97

108
@jax_funcify.register(Scan)
def jax_funcify_Scan(op: Scan, **kwargs):
    """Build a JAX implementation of a PyTensor `Scan` `Op`.

    The inner graph of the `Scan` is converted with `jax_funcify` and driven
    by a single `jax.lax.scan` call. The wrapper translates between the
    PyTensor layout of outer inputs/outputs and the ``(carry, x) -> (carry, y)``
    convention used by JAX.

    Parameters
    ----------
    op
        The `Scan` `Op` to convert.
    **kwargs
        Forwarded unchanged to `jax_funcify` when converting the inner graph.

    Returns
    -------
    Callable
        A function ``scan(*outer_inputs)`` that returns the list of traces
        (MIT-SOT, SIT-SOT, NIT-SOT, in that order) followed by the final
        shared outputs, or the single output itself when there is exactly one.

    Raises
    ------
    NotImplementedError
        If ``op.info.as_while`` is set or ``op.info.n_mit_mot`` is non-zero.
    """
    info = op.info

    if info.as_while:
        raise NotImplementedError("While Scan cannot yet be converted to JAX")

    if info.n_mit_mot:
        raise NotImplementedError(
            "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
        )

    # Optimize inner graph
    # NOTE(review): the rewriter is called on `op.fgraph` itself, not on a
    # copy -- confirm that modifying the Op's inner graph here is intended.
    rewriter = op.mode_instance.optimizer
    rewriter(op.fgraph)
    scan_inner_func = jax_funcify(op.fgraph, **kwargs)

    def scan(*outer_inputs):
        # Extract JAX scan inputs
        outer_inputs = list(outer_inputs)
        n_steps = outer_inputs[0]  # JAX `length`
        seqs = op.outer_seqs(outer_inputs)  # JAX `xs`

        # MIT-SOT: keep as many leading entries as the oldest (most negative)
        # tap reaches back; these form the sliding window carried by the loop
        mit_sot_init = [
            seq[: abs(min(tap))]
            for tap, seq in zip(info.mit_sot_in_slices, op.outer_mitsot(outer_inputs))
        ]

        # SIT-SOT: only the first entry of each buffer is the initial state
        sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]

        init_carry = (
            mit_sot_init,
            sit_sot_init,
            op.outer_shared(outer_inputs),
            op.outer_non_seqs(outer_inputs),
        )  # JAX `init`

        def jax_args_to_inner_func_args(carry, x):
            """Convert JAX scan arguments into format expected by scan_inner_func.

            scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
            """
            # `carry` contains all inner taps, non_seqs, and shared terms
            (
                inner_mit_sot,
                inner_sit_sot,
                inner_shared,
                inner_non_seqs,
            ) = carry

            # `x` contains the inner sequences
            inner_seqs = x

            # Each MIT-SOT window contributes one inner input per tap
            mit_sot_flatten = []
            for array, index in zip(inner_mit_sot, info.mit_sot_in_slices):
                mit_sot_flatten.extend(array[jnp.array(index)])

            # Concatenate lists
            return [
                *inner_seqs,
                *mit_sot_flatten,
                *inner_sit_sot,
                *inner_shared,
                *inner_non_seqs,
            ]

        def inner_func_outs_to_jax_outs(old_carry, inner_scan_outs):
            """Convert scan_inner_func outputs into format expected by JAX scan.

            old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
            """
            (
                inner_mit_sot,
                inner_sit_sot,
                inner_shared,
                inner_non_seqs,
            ) = old_carry

            inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
            inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
            inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
            inner_shared_outs = op.inner_shared_outs(inner_scan_outs)

            # Replace the oldest mit_sot tap by the newest value
            inner_mit_sot_new = [
                jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
                for old_mit_sot, new_val in zip(inner_mit_sot, inner_mit_sot_outs)
            ]

            # Nothing needs to be done with sit_sot
            inner_sit_sot_new = inner_sit_sot_outs

            # Work on a copy so the incoming carry is never mutated in place
            inner_shared_new = list(inner_shared)
            if inner_shared_outs:
                # Replace old shared inputs by new shared outputs
                inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs

            new_carry = (
                inner_mit_sot_new,
                inner_sit_sot_new,
                inner_shared_new,
                inner_non_seqs,
            )

            # Shared variables and non_seqs are not traced
            traced_outs = [
                *inner_mit_sot_outs,
                *inner_sit_sot_outs,
                *inner_nit_sot_outs,
            ]

            return new_carry, traced_outs

        def jax_inner_func(carry, x):
            inner_args = jax_args_to_inner_func_args(carry, x)
            inner_scan_outs = list(scan_inner_func(*inner_args))
            return inner_func_outs_to_jax_outs(carry, inner_scan_outs)

        # Extract PyTensor scan outputs
        final_carry, traces = jax.lax.scan(
            jax_inner_func, init_carry, seqs, length=n_steps
        )

        def get_partial_traces(traces):
            """Convert JAX scan traces to PyTensor traces.

            We need to:
                1. Prepend initial states to JAX output traces
                2. Slice final traces if Scan was instructed to only keep a portion
            """
            init_states = mit_sot_init + sit_sot_init + [None] * info.n_nit_sot
            buffers = (
                op.outer_mitsot(outer_inputs)
                + op.outer_sitsot(outer_inputs)
                + op.outer_nitsot(outer_inputs)
            )
            partial_traces = []
            for init_state, trace, buffer in zip(init_states, traces, buffers):
                full_trace = jnp.atleast_1d(trace)
                if init_state is not None:
                    # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer.
                    # A SIT-SOT initial state is `seq[0]` and so lacks the
                    # leading time axis that MIT-SOT slices and the trace
                    # have; add it so non-scalar states concatenate on axis 0.
                    init_state = jnp.asarray(init_state)
                    n_missing_dims = full_trace.ndim - init_state.ndim
                    if n_missing_dims > 0:
                        init_state = jnp.expand_dims(
                            init_state, tuple(range(n_missing_dims))
                        )
                    full_trace = jnp.concatenate([init_state, full_trace], axis=0)
                    buffer_size = buffer.shape[0]
                else:
                    # NIT-SOT: Buffer is just the number of entries that should be returned
                    buffer_size = buffer

                partial_traces.append(full_trace[-buffer_size:])

            return partial_traces

        def get_shared_outs(final_carry):
            """Retrieve last state of shared_outs from final_carry.

            These outputs cannot be traced in PyTensor Scan
            """
            if not info.n_shared_outs:
                return []

            # Carry layout: (mit_sot, sit_sot, shared, non_seqs)
            _, _, final_shared, _ = final_carry
            return list(final_shared[: info.n_shared_outs])

        scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)

        if len(scan_outs_final) == 1:
            scan_outs_final = scan_outs_final[0]
        return scan_outs_final

    return scan

0 commit comments

Comments
 (0)