Skip to content

Commit 88cc33b

Browse files
committed
Fix Scan JAX dispatcher
1 parent 7b60904 commit 88cc33b

File tree

2 files changed

+344
-136
lines changed

2 files changed

+344
-136
lines changed

pytensor/link/jax/dispatch/scan.py

+156-122
Original file line numberDiff line numberDiff line change
@@ -1,159 +1,193 @@
11
import jax
22
import jax.numpy as jnp
33

4-
from pytensor.graph.fg import FunctionGraph
54
from pytensor.link.jax.dispatch.basic import jax_funcify
65
from pytensor.scan.op import Scan
7-
from pytensor.scan.utils import ScanArgs
86

97

108
@jax_funcify.register(Scan)
11-
def jax_funcify_Scan(op, **kwargs):
12-
inner_fg = FunctionGraph(op.inputs, op.outputs)
13-
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
9+
def jax_funcify_Scan(op: Scan, **kwargs):
10+
info = op.info
1411

15-
def scan(*outer_inputs):
16-
scan_args = ScanArgs(
17-
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
12+
if info.as_while:
13+
raise NotImplementedError("While Scan cannot yet be converted to JAX")
14+
15+
if info.n_mit_mot:
16+
raise NotImplementedError(
17+
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
1818
)
1919

20-
# `outer_inputs` is a list with the following composite form:
21-
# [n_steps]
22-
# + outer_in_seqs
23-
# + outer_in_mit_mot
24-
# + outer_in_mit_sot
25-
# + outer_in_sit_sot
26-
# + outer_in_shared
27-
# + outer_in_nit_sot
28-
# + outer_in_non_seqs
29-
n_steps = scan_args.n_steps
30-
seqs = scan_args.outer_in_seqs
31-
32-
# TODO: mit_mots
33-
mit_mot_in_slices = []
34-
35-
mit_sot_in_slices = []
36-
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
37-
neg_taps = [abs(t) for t in tap if t < 0]
38-
pos_taps = [abs(t) for t in tap if t > 0]
39-
max_neg = max(neg_taps) if neg_taps else 0
40-
max_pos = max(pos_taps) if pos_taps else 0
41-
init_slice = seq[: max_neg + max_pos]
42-
mit_sot_in_slices.append(init_slice)
43-
44-
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
20+
# Optimize inner graph
21+
rewriter = op.mode_instance.optimizer
22+
rewriter(op.fgraph)
23+
scan_inner_func = jax_funcify(op.fgraph, **kwargs)
24+
25+
def scan(*outer_inputs):
26+
# Extract JAX scan inputs
27+
outer_inputs = list(outer_inputs)
28+
n_steps = outer_inputs[0] # JAX `length`
29+
seqs = op.outer_seqs(outer_inputs) # JAX `xs`
30+
31+
mit_sot_init = []
32+
for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)):
33+
init_slice = seq[: abs(min(tap))]
34+
mit_sot_init.append(init_slice)
35+
36+
sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
4537

4638
init_carry = (
47-
mit_mot_in_slices,
48-
mit_sot_in_slices,
49-
sit_sot_in_slices,
50-
scan_args.outer_in_shared,
51-
scan_args.outer_in_non_seqs,
52-
)
39+
mit_sot_init,
40+
sit_sot_init,
41+
op.outer_shared(outer_inputs),
42+
op.outer_non_seqs(outer_inputs),
43+
) # JAX `init`
44+
45+
def jax_args_to_inner_func_args(carry, x):
46+
"""Convert JAX scan arguments into format expected by scan_inner_func.
47+
48+
scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
49+
"""
5350

54-
def jax_args_to_inner_scan(op, carry, x):
55-
# `carry` contains all inner-output taps, non_seqs, and shared
56-
# terms
51+
# `carry` contains all inner taps, shared terms, and non_seqs
5752
(
58-
inner_in_mit_mot,
59-
inner_in_mit_sot,
60-
inner_in_sit_sot,
61-
inner_in_shared,
62-
inner_in_non_seqs,
53+
inner_mit_sot,
54+
inner_sit_sot,
55+
inner_shared,
56+
inner_non_seqs,
6357
) = carry
6458

65-
# `x` contains the in_seqs
66-
inner_in_seqs = x
67-
68-
# `inner_scan_inputs` is a list with the following composite form:
69-
# inner_in_seqs
70-
# + sum(inner_in_mit_mot, [])
71-
# + sum(inner_in_mit_sot, [])
72-
# + inner_in_sit_sot
73-
# + inner_in_shared
74-
# + inner_in_non_seqs
75-
inner_in_mit_sot_flatten = []
76-
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
77-
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
78-
79-
inner_scan_inputs = sum(
80-
[
81-
inner_in_seqs,
82-
inner_in_mit_mot,
83-
inner_in_mit_sot_flatten,
84-
inner_in_sit_sot,
85-
inner_in_shared,
86-
inner_in_non_seqs,
87-
],
88-
[],
89-
)
59+
# `x` contains the inner sequences
60+
inner_seqs = x
61+
62+
mit_sot_flatten = []
63+
for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices):
64+
mit_sot_flatten.extend(array[jnp.array(index)])
65+
66+
inner_scan_inputs = [
67+
*inner_seqs,
68+
*mit_sot_flatten,
69+
*inner_sit_sot,
70+
*inner_shared,
71+
*inner_non_seqs,
72+
]
9073

9174
return inner_scan_inputs
9275

93-
def inner_scan_outs_to_jax_outs(
94-
op,
76+
def inner_func_outs_to_jax_outs(
9577
old_carry,
9678
inner_scan_outs,
9779
):
80+
"""Convert scan_inner_func outputs into format expected by JAX scan.
81+
82+
old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
83+
"""
9884
(
99-
inner_in_mit_mot,
100-
inner_in_mit_sot,
101-
inner_in_sit_sot,
102-
inner_in_shared,
103-
inner_in_non_seqs,
85+
inner_mit_sot,
86+
inner_sit_sot,
87+
inner_shared,
88+
inner_non_seqs,
10489
) = old_carry
10590

106-
def update_mit_sot(mit_sot, new_val):
107-
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
108-
109-
inner_out_mit_sot = [
110-
update_mit_sot(mit_sot, new_val)
111-
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
91+
inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
92+
inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
93+
inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
94+
inner_shared_outs = op.inner_shared_outs(inner_scan_outs)
95+
96+
# Replace the oldest mit_sot tap by the newest value
97+
inner_mit_sot_new = [
98+
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
99+
for old_mit_sot, new_val in zip(
100+
inner_mit_sot,
101+
inner_mit_sot_outs,
102+
)
112103
]
113104

114-
# This should contain all inner-output taps, non_seqs, and shared
115-
# terms
116-
if not inner_in_sit_sot:
117-
inner_out_sit_sot = []
118-
else:
119-
inner_out_sit_sot = inner_scan_outs
105+
# Nothing needs to be done with sit_sot
106+
inner_sit_sot_new = inner_sit_sot_outs
107+
108+
inner_shared_new = inner_shared
109+
# Replace old shared inputs by new shared outputs
110+
inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs
111+
120112
new_carry = (
121-
inner_in_mit_mot,
113+
inner_mit_sot_new,
114+
inner_sit_sot_new,
115+
inner_shared_new,
116+
inner_non_seqs,
117+
)
118+
119+
# Shared variables and non_seqs are not traced
120+
traced_outs = [
121+
*inner_mit_sot_outs,
122+
*inner_sit_sot_outs,
123+
*inner_nit_sot_outs,
124+
]
125+
126+
return new_carry, traced_outs
127+
128+
def jax_inner_func(carry, x):
129+
inner_args = jax_args_to_inner_func_args(carry, x)
130+
inner_scan_outs = list(scan_inner_func(*inner_args))
131+
new_carry, traced_outs = inner_func_outs_to_jax_outs(carry, inner_scan_outs)
132+
return new_carry, traced_outs
133+
134+
# Extract PyTensor scan outputs
135+
final_carry, traces = jax.lax.scan(
136+
jax_inner_func, init_carry, seqs, length=n_steps
137+
)
138+
139+
def get_partial_traces(traces):
140+
"""Convert JAX scan traces to PyTensor traces.
141+
142+
We need to:
143+
1. Prepend initial states to JAX output traces
144+
2. Slice final traces if Scan was instructed to only keep a portion
145+
"""
146+
147+
init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
148+
buffers = (
149+
op.outer_mitsot(outer_inputs)
150+
+ op.outer_sitsot(outer_inputs)
151+
+ op.outer_nitsot(outer_inputs)
152+
)
153+
partial_traces = []
154+
for init_state, trace, buffer in zip(init_states, traces, buffers):
155+
if init_state is not None:
156+
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
157+
full_trace = jnp.concatenate(
158+
[jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
159+
axis=0,
160+
)
161+
buffer_size = buffer.shape[0]
162+
else:
163+
# NIT-SOT: Buffer is just the number of entries that should be returned
164+
full_trace = jnp.atleast_1d(trace)
165+
buffer_size = buffer
166+
167+
partial_trace = full_trace[-buffer_size:]
168+
partial_traces.append(partial_trace)
169+
170+
return partial_traces
171+
172+
def get_shared_outs(final_carry):
173+
"""Retrieve last state of shared_outs from final_carry.
174+
175+
These outputs cannot be traced in PyTensor Scan
176+
"""
177+
(
122178
inner_out_mit_sot,
123179
inner_out_sit_sot,
124-
inner_in_shared,
180+
inner_out_shared,
125181
inner_in_non_seqs,
126-
)
182+
) = final_carry
127183

128-
return new_carry
184+
shared_outs = inner_out_shared[: info.n_shared_outs]
185+
return list(shared_outs)
129186

130-
def jax_inner_func(carry, x):
131-
inner_args = jax_args_to_inner_scan(op, carry, x)
132-
inner_scan_outs = list(jax_at_inner_func(*inner_args))
133-
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
134-
return new_carry, inner_scan_outs
135-
136-
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
137-
138-
# We need to prepend the initial values so that the JAX output will
139-
# match the raw `Scan` `Op` output and, thus, work with a downstream
140-
# `Subtensor` `Op` introduced by the `scan` helper function.
141-
def append_scan_out(scan_in_part, scan_out_part):
142-
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
143-
144-
if scan_args.outer_in_mit_sot:
145-
scan_out_final = [
146-
append_scan_out(init, out)
147-
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
148-
]
149-
elif scan_args.outer_in_sit_sot:
150-
scan_out_final = [
151-
append_scan_out(init, out)
152-
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
153-
]
187+
scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
154188

155-
if len(scan_out_final) == 1:
156-
scan_out_final = scan_out_final[0]
157-
return scan_out_final
189+
if len(scan_outs_final) == 1:
190+
scan_outs_final = scan_outs_final[0]
191+
return scan_outs_final
158192

159193
return scan

0 commit comments

Comments
 (0)