Skip to content

Commit d5ba7fe

Browse files
committed
Fix Scan JAX dispatcher
1 parent eda9c46 commit d5ba7fe

File tree

2 files changed

+323
-136
lines changed

2 files changed

+323
-136
lines changed

pytensor/link/jax/dispatch/scan.py

+162-122
Original file line numberDiff line numberDiff line change
@@ -1,159 +1,199 @@
11
import jax
22
import jax.numpy as jnp
33

4-
from pytensor.graph.fg import FunctionGraph
54
from pytensor.link.jax.dispatch.basic import jax_funcify
65
from pytensor.scan.op import Scan
7-
from pytensor.scan.utils import ScanArgs
86

97

108
@jax_funcify.register(Scan)
11-
def jax_funcify_Scan(op, **kwargs):
12-
inner_fg = FunctionGraph(op.inputs, op.outputs)
13-
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
9+
def jax_funcify_Scan(op: Scan, **kwargs):
10+
info = op.info
1411

15-
def scan(*outer_inputs):
16-
scan_args = ScanArgs(
17-
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
12+
if info.as_while:
13+
raise NotImplementedError("While Scan cannot yet be converted to JAX")
14+
15+
if info.n_mit_mot:
16+
raise NotImplementedError(
17+
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
1818
)
1919

20-
# `outer_inputs` is a list with the following composite form:
21-
# [n_steps]
22-
# + outer_in_seqs
23-
# + outer_in_mit_mot
24-
# + outer_in_mit_sot
25-
# + outer_in_sit_sot
26-
# + outer_in_shared
27-
# + outer_in_nit_sot
28-
# + outer_in_non_seqs
29-
n_steps = scan_args.n_steps
30-
seqs = scan_args.outer_in_seqs
31-
32-
# TODO: mit_mots
33-
mit_mot_in_slices = []
34-
35-
mit_sot_in_slices = []
36-
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
37-
neg_taps = [abs(t) for t in tap if t < 0]
38-
pos_taps = [abs(t) for t in tap if t > 0]
39-
max_neg = max(neg_taps) if neg_taps else 0
40-
max_pos = max(pos_taps) if pos_taps else 0
41-
init_slice = seq[: max_neg + max_pos]
42-
mit_sot_in_slices.append(init_slice)
43-
44-
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
20+
# Optimize inner graph
21+
rewriter = op.mode_instance.optimizer
22+
rewriter(op.fgraph)
23+
scan_inner_func = jax_funcify(op.fgraph, **kwargs)
24+
25+
def scan(*outer_inputs):
26+
# Extract JAX scan inputs
27+
outer_inputs = list(outer_inputs)
28+
n_steps = outer_inputs[0] # JAX `length`
29+
seqs = op.outer_seqs(outer_inputs) # JAX `xs`
30+
31+
mit_sot_init = []
32+
for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)):
33+
init_slice = seq[: abs(min(tap))]
34+
mit_sot_init.append(init_slice)
35+
36+
sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
4537

4638
init_carry = (
47-
mit_mot_in_slices,
48-
mit_sot_in_slices,
49-
sit_sot_in_slices,
50-
scan_args.outer_in_shared,
51-
scan_args.outer_in_non_seqs,
52-
)
39+
mit_sot_init,
40+
sit_sot_init,
41+
op.outer_shared(outer_inputs),
42+
op.outer_non_seqs(outer_inputs),
43+
) # JAX `init`
44+
45+
def jax_args_to_inner_func_args(carry, x):
46+
"""Convert JAX scan arguments into format expected by scan_inner_func.
47+
48+
scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
49+
"""
5350

54-
def jax_args_to_inner_scan(op, carry, x):
55-
# `carry` contains all inner-output taps, non_seqs, and shared
56-
# terms
51+
# `carry` contains all inner taps, shared terms, and non_seqs
5752
(
58-
inner_in_mit_mot,
59-
inner_in_mit_sot,
60-
inner_in_sit_sot,
61-
inner_in_shared,
62-
inner_in_non_seqs,
53+
inner_mit_sot,
54+
inner_sit_sot,
55+
inner_shared,
56+
inner_non_seqs,
6357
) = carry
6458

65-
# `x` contains the in_seqs
66-
inner_in_seqs = x
67-
68-
# `inner_scan_inputs` is a list with the following composite form:
69-
# inner_in_seqs
70-
# + sum(inner_in_mit_mot, [])
71-
# + sum(inner_in_mit_sot, [])
72-
# + inner_in_sit_sot
73-
# + inner_in_shared
74-
# + inner_in_non_seqs
75-
inner_in_mit_sot_flatten = []
76-
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
77-
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
78-
79-
inner_scan_inputs = sum(
80-
[
81-
inner_in_seqs,
82-
inner_in_mit_mot,
83-
inner_in_mit_sot_flatten,
84-
inner_in_sit_sot,
85-
inner_in_shared,
86-
inner_in_non_seqs,
87-
],
88-
[],
89-
)
59+
# `x` contains the inner sequences
60+
inner_seqs = x
61+
62+
mit_sot_flatten = []
63+
for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices):
64+
mit_sot_flatten.extend(array[jnp.array(index)])
65+
66+
inner_scan_inputs = [
67+
*inner_seqs,
68+
*mit_sot_flatten,
69+
*inner_sit_sot,
70+
*inner_shared,
71+
*inner_non_seqs,
72+
]
9073

9174
return inner_scan_inputs
9275

93-
def inner_scan_outs_to_jax_outs(
94-
op,
76+
def inner_func_outs_to_jax_outs(
9577
old_carry,
9678
inner_scan_outs,
9779
):
80+
"""Convert scan_inner_func outputs into the format expected by JAX scan.
81+
82+
old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
83+
"""
9884
(
99-
inner_in_mit_mot,
100-
inner_in_mit_sot,
101-
inner_in_sit_sot,
102-
inner_in_shared,
103-
inner_in_non_seqs,
85+
inner_mit_sot,
86+
inner_sit_sot,
87+
inner_shared,
88+
inner_non_seqs,
10489
) = old_carry
10590

106-
def update_mit_sot(mit_sot, new_val):
107-
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
91+
inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
92+
inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
93+
inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
94+
inner_shared_outs = op.inner_shared_outs(inner_scan_outs)
95+
96+
# Replace the oldest mit_sot tap by the newest value
97+
inner_mit_sot_new = []
98+
if inner_mit_sot:
99+
inner_mit_sot_new = [
100+
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
101+
for old_mit_sot, new_val in zip(
102+
inner_mit_sot,
103+
inner_mit_sot_outs,
104+
)
105+
]
106+
107+
# Nothing needs to be done with sit_sot
108+
inner_sit_sot_new = inner_sit_sot_outs
109+
110+
inner_shared_new = inner_shared
111+
if inner_shared_outs:
112+
# Replace old shared inputs by new shared outputs
113+
inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs
108114

109-
inner_out_mit_sot = [
110-
update_mit_sot(mit_sot, new_val)
111-
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
115+
new_carry = (
116+
inner_mit_sot_new,
117+
inner_sit_sot_new,
118+
inner_shared_new,
119+
inner_non_seqs,
120+
)
121+
122+
# Shared variables and non_seqs are not traced
123+
traced_outs = [
124+
*inner_mit_sot_outs,
125+
*inner_sit_sot_outs,
126+
*inner_nit_sot_outs,
112127
]
113128

114-
# This should contain all inner-output taps, non_seqs, and shared
115-
# terms
116-
if not inner_in_sit_sot:
117-
inner_out_sit_sot = []
118-
else:
119-
inner_out_sit_sot = inner_scan_outs
120-
new_carry = (
121-
inner_in_mit_mot,
129+
return new_carry, traced_outs
130+
131+
def jax_inner_func(carry, x):
132+
inner_args = jax_args_to_inner_func_args(carry, x)
133+
inner_scan_outs = list(scan_inner_func(*inner_args))
134+
new_carry, traced_outs = inner_func_outs_to_jax_outs(carry, inner_scan_outs)
135+
return new_carry, traced_outs
136+
137+
# Extract PyTensor scan outputs
138+
final_carry, traces = jax.lax.scan(
139+
jax_inner_func, init_carry, seqs, length=n_steps
140+
)
141+
142+
def get_partial_traces(traces):
143+
"""Convert JAX scan traces to PyTensor traces.
144+
145+
We need to:
146+
1. Prepend initial states to JAX output traces
147+
2. Slice final traces if Scan was instructed to only keep a portion
148+
"""
149+
150+
init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
151+
buffers = (
152+
op.outer_mitsot(outer_inputs)
153+
+ op.outer_sitsot(outer_inputs)
154+
+ op.outer_nitsot(outer_inputs)
155+
)
156+
partial_traces = []
157+
for init_state, trace, buffer in zip(init_states, traces, buffers):
158+
if init_state is not None:
159+
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
160+
full_trace = jnp.concatenate(
161+
[jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
162+
axis=0,
163+
)
164+
buffer_size = buffer.shape[0]
165+
else:
166+
# NIT-SOT: Buffer is just the number of entries that should be returned
167+
full_trace = jnp.atleast_1d(trace)
168+
buffer_size = buffer
169+
170+
partial_trace = full_trace[-buffer_size:]
171+
partial_traces.append(partial_trace)
172+
173+
return partial_traces
174+
175+
def get_shared_outs(final_carry):
176+
"""Retrieve last state of shared_outs from final_carry.
177+
178+
These outputs cannot be traced in PyTensor Scan
179+
"""
180+
if not info.n_shared_outs:
181+
return []
182+
183+
(
122184
inner_out_mit_sot,
123185
inner_out_sit_sot,
124-
inner_in_shared,
186+
inner_out_shared,
125187
inner_in_non_seqs,
126-
)
188+
) = final_carry
127189

128-
return new_carry
190+
shared_outs = inner_out_shared[: info.n_shared_outs]
191+
return list(shared_outs)
129192

130-
def jax_inner_func(carry, x):
131-
inner_args = jax_args_to_inner_scan(op, carry, x)
132-
inner_scan_outs = list(jax_at_inner_func(*inner_args))
133-
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
134-
return new_carry, inner_scan_outs
135-
136-
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
137-
138-
# We need to prepend the initial values so that the JAX output will
139-
# match the raw `Scan` `Op` output and, thus, work with a downstream
140-
# `Subtensor` `Op` introduced by the `scan` helper function.
141-
def append_scan_out(scan_in_part, scan_out_part):
142-
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
143-
144-
if scan_args.outer_in_mit_sot:
145-
scan_out_final = [
146-
append_scan_out(init, out)
147-
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
148-
]
149-
elif scan_args.outer_in_sit_sot:
150-
scan_out_final = [
151-
append_scan_out(init, out)
152-
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
153-
]
193+
scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
154194

155-
if len(scan_out_final) == 1:
156-
scan_out_final = scan_out_final[0]
157-
return scan_out_final
195+
if len(scan_outs_final) == 1:
196+
scan_outs_final = scan_outs_final[0]
197+
return scan_outs_final
158198

159199
return scan

0 commit comments

Comments
 (0)