10
10
import pytensor .tensor .random .basic as aer
11
11
from pytensor .link .jax .dispatch .basic import jax_funcify , jax_typify
12
12
from pytensor .link .jax .dispatch .shape import JAXShapeTuple
13
+ from pytensor .tensor .random .type import RandomType
13
14
from pytensor .tensor .shape import Shape , Shape_i
14
15
15
16
@@ -57,8 +58,7 @@ def jax_typify_RandomState(state, **kwargs):
57
58
state = state .get_state (legacy = False )
58
59
state ["bit_generator" ] = numpy_bit_gens [state ["bit_generator" ]]
59
60
# XXX: Is this a reasonable approach?
60
- state ["jax_state" ] = state ["state" ]["key" ][0 :2 ]
61
- return state
61
+ return state ["state" ]["key" ][0 :2 ]
62
62
63
63
64
64
@jax_typify .register (Generator )
@@ -83,7 +83,36 @@ def jax_typify_Generator(rng, **kwargs):
83
83
state_32 = _coerce_to_uint32_array (state ["state" ]["state" ])
84
84
state ["state" ]["inc" ] = inc_32 [0 ] << 32 | inc_32 [1 ]
85
85
state ["state" ]["state" ] = state_32 [0 ] << 32 | state_32 [1 ]
86
- return state
86
+ return state ["jax_state" ]
87
+
88
+
89
+ class RandomPRNGKeyType (RandomType [jax .random .PRNGKey ]):
90
+ """JAX-compatible PRNGKey type.
91
+
92
+ This type is not exposed to users directly.
93
+
94
+ It is introduced by the JIT linker in place of any RandomType input
95
+ variables used in the original function. Nodes in the function graph will
96
+ still show the original types as inputs and outputs.
97
+ """
98
+
99
+ def filter (self , data , strict : bool = False , allow_downcast = None ):
100
+ # PRNGs are just JAX Arrays, we assume this is a valid one!
101
+ if isinstance (data , jax .Array ):
102
+ return data
103
+
104
+ if strict :
105
+ raise TypeError ()
106
+
107
+ return jax_typify (data )
108
+
109
+
110
+ random_prng_key_type = RandomPRNGKeyType ()
111
+
112
+
113
+ @jax_typify .register (RandomType )
114
+ def jax_typify_RandomType (type ):
115
+ return random_prng_key_type ()
87
116
88
117
89
118
@jax_funcify .register (aer .RandomVariable )
@@ -130,12 +159,10 @@ def jax_sample_fn_generic(op):
130
159
name = op .name
131
160
jax_op = getattr (jax .random , name )
132
161
133
- def sample_fn (rng , size , dtype , * parameters ):
134
- rng_key = rng ["jax_state" ]
162
+ def sample_fn (rng_key , size , dtype , * parameters ):
135
163
rng_key , sampling_key = jax .random .split (rng_key , 2 )
136
164
sample = jax_op (sampling_key , * parameters , shape = size , dtype = dtype )
137
- rng ["jax_state" ] = rng_key
138
- return (rng , sample )
165
+ return (rng_key , sample )
139
166
140
167
return sample_fn
141
168
@@ -157,13 +184,11 @@ def jax_sample_fn_loc_scale(op):
157
184
name = op .name
158
185
jax_op = getattr (jax .random , name )
159
186
160
- def sample_fn (rng , size , dtype , * parameters ):
161
- rng_key = rng ["jax_state" ]
187
+ def sample_fn (rng_key , size , dtype , * parameters ):
162
188
rng_key , sampling_key = jax .random .split (rng_key , 2 )
163
189
loc , scale = parameters
164
190
sample = loc + jax_op (sampling_key , size , dtype ) * scale
165
- rng ["jax_state" ] = rng_key
166
- return (rng , sample )
191
+ return (rng_key , sample )
167
192
168
193
return sample_fn
169
194
@@ -175,12 +200,10 @@ def jax_sample_fn_no_dtype(op):
175
200
name = op .name
176
201
jax_op = getattr (jax .random , name )
177
202
178
- def sample_fn (rng , size , dtype , * parameters ):
179
- rng_key = rng ["jax_state" ]
203
+ def sample_fn (rng_key , size , dtype , * parameters ):
180
204
rng_key , sampling_key = jax .random .split (rng_key , 2 )
181
205
sample = jax_op (sampling_key , * parameters , shape = size )
182
- rng ["jax_state" ] = rng_key
183
- return (rng , sample )
206
+ return (rng_key , sample )
184
207
185
208
return sample_fn
186
209
@@ -201,15 +224,13 @@ def jax_sample_fn_uniform(op):
201
224
name = "randint"
202
225
jax_op = getattr (jax .random , name )
203
226
204
- def sample_fn (rng , size , dtype , * parameters ):
205
- rng_key = rng ["jax_state" ]
227
+ def sample_fn (rng_key , size , dtype , * parameters ):
206
228
rng_key , sampling_key = jax .random .split (rng_key , 2 )
207
229
minval , maxval = parameters
208
230
sample = jax_op (
209
231
sampling_key , shape = size , dtype = dtype , minval = minval , maxval = maxval
210
232
)
211
- rng ["jax_state" ] = rng_key
212
- return (rng , sample )
233
+ return (rng_key , sample )
213
234
214
235
return sample_fn
215
236
@@ -226,13 +247,11 @@ def jax_sample_fn_shape_rate(op):
226
247
name = op .name
227
248
jax_op = getattr (jax .random , name )
228
249
229
- def sample_fn (rng , size , dtype , * parameters ):
230
- rng_key = rng ["jax_state" ]
250
+ def sample_fn (rng_key , size , dtype , * parameters ):
231
251
rng_key , sampling_key = jax .random .split (rng_key , 2 )
232
252
(shape , rate ) = parameters
233
253
sample = jax_op (sampling_key , shape , size , dtype ) / rate
234
- rng ["jax_state" ] = rng_key
235
- return (rng , sample )
254
+ return (rng_key , sample )
236
255
237
256
return sample_fn
238
257
@@ -241,13 +260,11 @@ def sample_fn(rng, size, dtype, *parameters):
241
260
def jax_sample_fn_exponential (op ):
242
261
"""JAX implementation of `ExponentialRV`."""
243
262
244
- def sample_fn (rng , size , dtype , * parameters ):
245
- rng_key = rng ["jax_state" ]
263
+ def sample_fn (rng_key , size , dtype , * parameters ):
246
264
rng_key , sampling_key = jax .random .split (rng_key , 2 )
247
265
(scale ,) = parameters
248
266
sample = jax .random .exponential (sampling_key , size , dtype ) * scale
249
- rng ["jax_state" ] = rng_key
250
- return (rng , sample )
267
+ return (rng_key , sample )
251
268
252
269
return sample_fn
253
270
@@ -256,17 +273,15 @@ def sample_fn(rng, size, dtype, *parameters):
256
273
def jax_sample_fn_t (op ):
257
274
"""JAX implementation of `StudentTRV`."""
258
275
259
- def sample_fn (rng , size , dtype , * parameters ):
260
- rng_key = rng ["jax_state" ]
276
+ def sample_fn (rng_key , size , dtype , * parameters ):
261
277
rng_key , sampling_key = jax .random .split (rng_key , 2 )
262
278
(
263
279
df ,
264
280
loc ,
265
281
scale ,
266
282
) = parameters
267
283
sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
268
- rng ["jax_state" ] = rng_key
269
- return (rng , sample )
284
+ return (rng_key , sample )
270
285
271
286
return sample_fn
272
287
@@ -275,13 +290,11 @@ def sample_fn(rng, size, dtype, *parameters):
275
290
def jax_funcify_choice (op ):
276
291
"""JAX implementation of `ChoiceRV`."""
277
292
278
- def sample_fn (rng , size , dtype , * parameters ):
279
- rng_key = rng ["jax_state" ]
293
+ def sample_fn (rng_key , size , dtype , * parameters ):
280
294
rng_key , sampling_key = jax .random .split (rng_key , 2 )
281
295
(a , p , replace ) = parameters
282
296
smpl_value = jax .random .choice (sampling_key , a , size , replace , p )
283
- rng ["jax_state" ] = rng_key
284
- return (rng , smpl_value )
297
+ return (rng_key , smpl_value )
285
298
286
299
return sample_fn
287
300
@@ -290,13 +303,11 @@ def sample_fn(rng, size, dtype, *parameters):
290
303
def jax_sample_fn_permutation (op ):
291
304
"""JAX implementation of `PermutationRV`."""
292
305
293
- def sample_fn (rng , size , dtype , * parameters ):
294
- rng_key = rng ["jax_state" ]
306
+ def sample_fn (rng_key , size , dtype , * parameters ):
295
307
rng_key , sampling_key = jax .random .split (rng_key , 2 )
296
308
(x ,) = parameters
297
309
sample = jax .random .permutation (sampling_key , x )
298
- rng ["jax_state" ] = rng_key
299
- return (rng , sample )
310
+ return (rng_key , sample )
300
311
301
312
return sample_fn
302
313
@@ -311,15 +322,12 @@ def jax_sample_fn_binomial(op):
311
322
312
323
from numpyro .distributions .util import binomial
313
324
314
- def sample_fn (rng , size , dtype , n , p ):
315
- rng_key = rng ["jax_state" ]
325
+ def sample_fn (rng_key , size , dtype , n , p ):
316
326
rng_key , sampling_key = jax .random .split (rng_key , 2 )
317
327
318
328
sample = binomial (key = sampling_key , n = n , p = p , shape = size )
319
329
320
- rng ["jax_state" ] = rng_key
321
-
322
- return (rng , sample )
330
+ return (rng_key , sample )
323
331
324
332
return sample_fn
325
333
@@ -334,15 +342,12 @@ def jax_sample_fn_multinomial(op):
334
342
335
343
from numpyro .distributions .util import multinomial
336
344
337
- def sample_fn (rng , size , dtype , n , p ):
338
- rng_key = rng ["jax_state" ]
345
+ def sample_fn (rng_key , size , dtype , n , p ):
339
346
rng_key , sampling_key = jax .random .split (rng_key , 2 )
340
347
341
348
sample = multinomial (key = sampling_key , n = n , p = p , shape = size )
342
349
343
- rng ["jax_state" ] = rng_key
344
-
345
- return (rng , sample )
350
+ return (rng_key , sample )
346
351
347
352
return sample_fn
348
353
@@ -357,17 +362,14 @@ def jax_sample_fn_vonmises(op):
357
362
358
363
from numpyro .distributions .util import von_mises_centered
359
364
360
- def sample_fn (rng , size , dtype , mu , kappa ):
361
- rng_key = rng ["jax_state" ]
365
+ def sample_fn (rng_key , size , dtype , mu , kappa ):
362
366
rng_key , sampling_key = jax .random .split (rng_key , 2 )
363
367
364
368
sample = von_mises_centered (
365
369
key = sampling_key , concentration = kappa , shape = size , dtype = dtype
366
370
)
367
371
sample = (sample + mu + np .pi ) % (2.0 * np .pi ) - np .pi
368
372
369
- rng ["jax_state" ] = rng_key
370
-
371
- return (rng , sample )
373
+ return (rng_key , sample )
372
374
373
375
return sample_fn
0 commit comments