10
10
import pytensor .tensor .random .basic as aer
11
11
from pytensor .link .jax .dispatch .basic import jax_funcify , jax_typify
12
12
from pytensor .link .jax .dispatch .shape import JAXShapeTuple
13
+ from pytensor .tensor .random .type import RandomType
13
14
from pytensor .tensor .shape import Shape , Shape_i
14
15
15
16
@@ -55,8 +56,7 @@ def jax_typify_RandomState(state, **kwargs):
55
56
state = state .get_state (legacy = False )
56
57
state ["bit_generator" ] = numpy_bit_gens [state ["bit_generator" ]]
57
58
# XXX: Is this a reasonable approach?
58
- state ["jax_state" ] = state ["state" ]["key" ][0 :2 ]
59
- return state
59
+ return state ["state" ]["key" ][0 :2 ]
60
60
61
61
62
62
@jax_typify .register (Generator )
@@ -81,7 +81,27 @@ def jax_typify_Generator(rng, **kwargs):
81
81
state_32 = _coerce_to_uint32_array (state ["state" ]["state" ])
82
82
state ["state" ]["inc" ] = inc_32 [0 ] << 32 | inc_32 [1 ]
83
83
state ["state" ]["state" ] = state_32 [0 ] << 32 | state_32 [1 ]
84
- return state
84
+ return state ["jax_state" ]
85
+
86
+
87
+ class RandomPRNGKeyType (RandomType [jax .random .PRNGKey ]):
88
+ def filter (self , data , strict : bool = False , allow_downcast = None ):
89
+ # PRNGs are just JAX Arrays, we assume this is a valid one!
90
+ if isinstance (data , jax .Array ):
91
+ return data
92
+
93
+ if strict :
94
+ raise TypeError ()
95
+
96
+ return jax_typify (data )
97
+
98
+
99
+ random_prng_key_type = RandomPRNGKeyType ()
100
+
101
+
102
+ @jax_typify .register (RandomType )
103
+ def jax_typify_RandomType (type ):
104
+ return random_prng_key_type ()
85
105
86
106
87
107
@jax_funcify .register (aer .RandomVariable )
@@ -128,12 +148,10 @@ def jax_sample_fn_generic(op):
128
148
name = op .name
129
149
jax_op = getattr (jax .random , name )
130
150
131
- def sample_fn (rng , size , dtype , * parameters ):
132
- rng_key = rng ["jax_state" ]
151
+ def sample_fn (rng_key , size , dtype , * parameters ):
133
152
rng_key , sampling_key = jax .random .split (rng_key , 2 )
134
153
sample = jax_op (sampling_key , * parameters , shape = size , dtype = dtype )
135
- rng ["jax_state" ] = rng_key
136
- return (rng , sample )
154
+ return (rng_key , sample )
137
155
138
156
return sample_fn
139
157
@@ -155,13 +173,11 @@ def jax_sample_fn_loc_scale(op):
155
173
name = op .name
156
174
jax_op = getattr (jax .random , name )
157
175
158
- def sample_fn (rng , size , dtype , * parameters ):
159
- rng_key = rng ["jax_state" ]
176
+ def sample_fn (rng_key , size , dtype , * parameters ):
160
177
rng_key , sampling_key = jax .random .split (rng_key , 2 )
161
178
loc , scale = parameters
162
179
sample = loc + jax_op (sampling_key , size , dtype ) * scale
163
- rng ["jax_state" ] = rng_key
164
- return (rng , sample )
180
+ return (rng_key , sample )
165
181
166
182
return sample_fn
167
183
@@ -173,12 +189,10 @@ def jax_sample_fn_no_dtype(op):
173
189
name = op .name
174
190
jax_op = getattr (jax .random , name )
175
191
176
- def sample_fn (rng , size , dtype , * parameters ):
177
- rng_key = rng ["jax_state" ]
192
+ def sample_fn (rng_key , size , dtype , * parameters ):
178
193
rng_key , sampling_key = jax .random .split (rng_key , 2 )
179
194
sample = jax_op (sampling_key , * parameters , shape = size )
180
- rng ["jax_state" ] = rng_key
181
- return (rng , sample )
195
+ return (rng_key , sample )
182
196
183
197
return sample_fn
184
198
@@ -199,15 +213,13 @@ def jax_sample_fn_uniform(op):
199
213
name = "randint"
200
214
jax_op = getattr (jax .random , name )
201
215
202
- def sample_fn (rng , size , dtype , * parameters ):
203
- rng_key = rng ["jax_state" ]
216
+ def sample_fn (rng_key , size , dtype , * parameters ):
204
217
rng_key , sampling_key = jax .random .split (rng_key , 2 )
205
218
minval , maxval = parameters
206
219
sample = jax_op (
207
220
sampling_key , shape = size , dtype = dtype , minval = minval , maxval = maxval
208
221
)
209
- rng ["jax_state" ] = rng_key
210
- return (rng , sample )
222
+ return (rng_key , sample )
211
223
212
224
return sample_fn
213
225
@@ -224,13 +236,11 @@ def jax_sample_fn_shape_rate(op):
224
236
name = op .name
225
237
jax_op = getattr (jax .random , name )
226
238
227
- def sample_fn (rng , size , dtype , * parameters ):
228
- rng_key = rng ["jax_state" ]
239
+ def sample_fn (rng_key , size , dtype , * parameters ):
229
240
rng_key , sampling_key = jax .random .split (rng_key , 2 )
230
241
(shape , rate ) = parameters
231
242
sample = jax_op (sampling_key , shape , size , dtype ) / rate
232
- rng ["jax_state" ] = rng_key
233
- return (rng , sample )
243
+ return (rng_key , sample )
234
244
235
245
return sample_fn
236
246
@@ -239,13 +249,11 @@ def sample_fn(rng, size, dtype, *parameters):
239
249
def jax_sample_fn_exponential (op ):
240
250
"""JAX implementation of `ExponentialRV`."""
241
251
242
- def sample_fn (rng , size , dtype , * parameters ):
243
- rng_key = rng ["jax_state" ]
252
+ def sample_fn (rng_key , size , dtype , * parameters ):
244
253
rng_key , sampling_key = jax .random .split (rng_key , 2 )
245
254
(scale ,) = parameters
246
255
sample = jax .random .exponential (sampling_key , size , dtype ) * scale
247
- rng ["jax_state" ] = rng_key
248
- return (rng , sample )
256
+ return (rng_key , sample )
249
257
250
258
return sample_fn
251
259
@@ -254,17 +262,15 @@ def sample_fn(rng, size, dtype, *parameters):
254
262
def jax_sample_fn_t (op ):
255
263
"""JAX implementation of `StudentTRV`."""
256
264
257
- def sample_fn (rng , size , dtype , * parameters ):
258
- rng_key = rng ["jax_state" ]
265
+ def sample_fn (rng_key , size , dtype , * parameters ):
259
266
rng_key , sampling_key = jax .random .split (rng_key , 2 )
260
267
(
261
268
df ,
262
269
loc ,
263
270
scale ,
264
271
) = parameters
265
272
sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
266
- rng ["jax_state" ] = rng_key
267
- return (rng , sample )
273
+ return (rng_key , sample )
268
274
269
275
return sample_fn
270
276
@@ -273,13 +279,11 @@ def sample_fn(rng, size, dtype, *parameters):
273
279
def jax_funcify_choice (op ):
274
280
"""JAX implementation of `ChoiceRV`."""
275
281
276
- def sample_fn (rng , size , dtype , * parameters ):
277
- rng_key = rng ["jax_state" ]
282
+ def sample_fn (rng_key , size , dtype , * parameters ):
278
283
rng_key , sampling_key = jax .random .split (rng_key , 2 )
279
284
(a , p , replace ) = parameters
280
285
smpl_value = jax .random .choice (sampling_key , a , size , replace , p )
281
- rng ["jax_state" ] = rng_key
282
- return (rng , smpl_value )
286
+ return (rng_key , smpl_value )
283
287
284
288
return sample_fn
285
289
@@ -288,13 +292,11 @@ def sample_fn(rng, size, dtype, *parameters):
288
292
def jax_sample_fn_permutation (op ):
289
293
"""JAX implementation of `PermutationRV`."""
290
294
291
- def sample_fn (rng , size , dtype , * parameters ):
292
- rng_key = rng ["jax_state" ]
295
+ def sample_fn (rng_key , size , dtype , * parameters ):
293
296
rng_key , sampling_key = jax .random .split (rng_key , 2 )
294
297
(x ,) = parameters
295
298
sample = jax .random .permutation (sampling_key , x )
296
- rng ["jax_state" ] = rng_key
297
- return (rng , sample )
299
+ return (rng_key , sample )
298
300
299
301
return sample_fn
300
302
@@ -309,15 +311,12 @@ def jax_sample_fn_binomial(op):
309
311
310
312
from numpyro .distributions .util import binomial
311
313
312
- def sample_fn (rng , size , dtype , n , p ):
313
- rng_key = rng ["jax_state" ]
314
+ def sample_fn (rng_key , size , dtype , n , p ):
314
315
rng_key , sampling_key = jax .random .split (rng_key , 2 )
315
316
316
317
sample = binomial (key = sampling_key , n = n , p = p , shape = size )
317
318
318
- rng ["jax_state" ] = rng_key
319
-
320
- return (rng , sample )
319
+ return (rng_key , sample )
321
320
322
321
return sample_fn
323
322
@@ -332,15 +331,12 @@ def jax_sample_fn_multinomial(op):
332
331
333
332
from numpyro .distributions .util import multinomial
334
333
335
- def sample_fn (rng , size , dtype , n , p ):
336
- rng_key = rng ["jax_state" ]
334
+ def sample_fn (rng_key , size , dtype , n , p ):
337
335
rng_key , sampling_key = jax .random .split (rng_key , 2 )
338
336
339
337
sample = multinomial (key = sampling_key , n = n , p = p , shape = size )
340
338
341
- rng ["jax_state" ] = rng_key
342
-
343
- return (rng , sample )
339
+ return (rng_key , sample )
344
340
345
341
return sample_fn
346
342
@@ -355,17 +351,14 @@ def jax_sample_fn_vonmises(op):
355
351
356
352
from numpyro .distributions .util import von_mises_centered
357
353
358
- def sample_fn (rng , size , dtype , mu , kappa ):
359
- rng_key = rng ["jax_state" ]
354
+ def sample_fn (rng_key , size , dtype , mu , kappa ):
360
355
rng_key , sampling_key = jax .random .split (rng_key , 2 )
361
356
362
357
sample = von_mises_centered (
363
358
key = sampling_key , concentration = kappa , shape = size , dtype = dtype
364
359
)
365
360
sample = (sample + mu + np .pi ) % (2.0 * np .pi ) - np .pi
366
361
367
- rng ["jax_state" ] = rng_key
368
-
369
- return (rng , sample )
362
+ return (rng_key , sample )
370
363
371
364
return sample_fn
0 commit comments