1
1
import jax
2
2
import jax .numpy as jnp
3
3
4
- from pytensor .graph .fg import FunctionGraph
5
4
from pytensor .link .jax .dispatch .basic import jax_funcify
6
5
from pytensor .scan .op import Scan
7
6
from pytensor .scan .utils import ScanArgs
8
7
9
8
10
9
@jax_funcify .register (Scan )
11
10
def jax_funcify_Scan (op , ** kwargs ):
12
- inner_fg = FunctionGraph (op .inputs , op .outputs )
13
- jax_at_inner_func = jax_funcify (inner_fg , ** kwargs )
11
+ # TODO: Raise NotImplementedError if While scan
12
+
13
+ # Apply inner rewrites
14
+ # TODO: Not sure this is the right place to do this, should we have a rewrite that
15
+ # explicitly triggers the optimization of the inner graphs of Scan?
16
+ # The C-code defers it to the make_thunk phase
17
+ fgraph = op .fgraph .clone ()
18
+ rewriter = op .mode_instance .optimizer
19
+ rewriter (fgraph )
20
+ scan_inner_func = jax_funcify (fgraph , ** kwargs )
14
21
15
22
def scan (* outer_inputs ):
16
23
scan_args = ScanArgs (
17
- list (outer_inputs ), [None ] * op .info .n_outs , op .inputs , op .outputs , op .info
24
+ list (outer_inputs ),
25
+ [None ] * len (op .inner_outputs ),
26
+ op .inner_inputs ,
27
+ op .inner_outputs ,
28
+ op .info ,
18
29
)
19
30
20
31
# `outer_inputs` is a list with the following composite form:
@@ -29,16 +40,13 @@ def scan(*outer_inputs):
29
40
n_steps = scan_args .n_steps
30
41
seqs = scan_args .outer_in_seqs
31
42
32
- # TODO: mit_mots
33
43
mit_mot_in_slices = []
44
+ if scan_args .outer_in_mit_mot :
45
+ raise NotImplementedError ("JAX Scan with MIT-MOT not supported yet." )
34
46
35
47
mit_sot_in_slices = []
36
48
for tap , seq in zip (scan_args .mit_sot_in_slices , scan_args .outer_in_mit_sot ):
37
- neg_taps = [abs (t ) for t in tap if t < 0 ]
38
- pos_taps = [abs (t ) for t in tap if t > 0 ]
39
- max_neg = max (neg_taps ) if neg_taps else 0
40
- max_pos = max (pos_taps ) if pos_taps else 0
41
- init_slice = seq [: max_neg + max_pos ]
49
+ init_slice = seq [: abs (min (tap ))]
42
50
mit_sot_in_slices .append (init_slice )
43
51
44
52
sit_sot_in_slices = [seq [0 ] for seq in scan_args .outer_in_sit_sot ]
@@ -76,6 +84,7 @@ def jax_args_to_inner_scan(op, carry, x):
76
84
for array , index in zip (inner_in_mit_sot , scan_args .mit_sot_in_slices ):
77
85
inner_in_mit_sot_flatten .extend (array [jnp .array (index )])
78
86
87
+ # Concatenate lists
79
88
inner_scan_inputs = sum (
80
89
[
81
90
inner_in_seqs ,
@@ -116,9 +125,13 @@ def update_mit_sot(mit_sot, new_val):
116
125
if not inner_in_sit_sot :
117
126
inner_out_sit_sot = []
118
127
else :
119
- inner_out_sit_sot = inner_scan_outs
128
+ inner_out_sit_sot = inner_scan_outs [
129
+ len (inner_in_mit_sot ) : len (inner_in_mit_sot )
130
+ + len (inner_in_sit_sot )
131
+ ]
132
+
120
133
new_carry = (
121
- inner_in_mit_mot ,
134
+ inner_in_mit_mot , # Just an empty list, we raise earlier if there are any MIT-MOT
122
135
inner_out_mit_sot ,
123
136
inner_out_sit_sot ,
124
137
inner_in_shared ,
@@ -129,28 +142,39 @@ def update_mit_sot(mit_sot, new_val):
129
142
130
143
def jax_inner_func (carry , x ):
131
144
inner_args = jax_args_to_inner_scan (op , carry , x )
132
- inner_scan_outs = list (jax_at_inner_func (* inner_args ))
145
+ inner_scan_outs = list (scan_inner_func (* inner_args ))
133
146
new_carry = inner_scan_outs_to_jax_outs (op , carry , inner_scan_outs )
134
147
return new_carry , inner_scan_outs
135
148
136
- _ , scan_out = jax .lax .scan (jax_inner_func , init_carry , seqs , length = n_steps )
149
+ _ , scan_outs = jax .lax .scan (jax_inner_func , init_carry , seqs , length = n_steps )
137
150
138
151
# We need to prepend the initial values so that the JAX output will
139
152
# match the raw `Scan` `Op` output and, thus, work with a downstream
140
153
# `Subtensor` `Op` introduced by the `scan` helper function.
141
- def append_scan_out (scan_in_part , scan_out_part ):
142
- return jnp .concatenate ([scan_in_part [:- n_steps ], scan_out_part ], axis = 0 )
143
-
144
- if scan_args .outer_in_mit_sot :
145
- scan_out_final = [
146
- append_scan_out (init , out )
147
- for init , out in zip (scan_args .outer_in_mit_sot , scan_out )
148
- ]
149
- elif scan_args .outer_in_sit_sot :
150
- scan_out_final = [
151
- append_scan_out (init , out )
152
- for init , out in zip (scan_args .outer_in_sit_sot , scan_out )
153
- ]
154
+ scan_out_final = []
155
+ for init , scan_out , buffer in zip (
156
+ mit_sot_in_slices
157
+ + sit_sot_in_slices
158
+ + [None ] * len (scan_args .outer_in_nit_sot ),
159
+ scan_outs ,
160
+ scan_args .outer_in_mit_sot
161
+ + scan_args .outer_in_sit_sot
162
+ + scan_args .outer_in_nit_sot ,
163
+ ):
164
+ if init is not None :
165
+ # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
166
+ full_scan_out = jnp .concatenate (
167
+ [
168
+ jnp .atleast_1d (init ),
169
+ jnp .atleast_1d (scan_out ),
170
+ ],
171
+ axis = 0 ,
172
+ )
173
+ partial_scan_out = full_scan_out [- buffer .shape [0 ] :]
174
+ else :
175
+ # NIT-SOT: Buffer is just the number of entries that should be returned
176
+ partial_scan_out = jnp .atleast_1d (scan_out )[- buffer :]
177
+ scan_out_final .append (partial_scan_out )
154
178
155
179
if len (scan_out_final ) == 1 :
156
180
scan_out_final = scan_out_final [0 ]
0 commit comments