Skip to content

Commit 751365e

Browse files
committed
Typify RNG input variables in JAX linker
1 parent cd44a2b commit 751365e

File tree

4 files changed

+85
-56
lines changed

4 files changed

+85
-56
lines changed

pytensor/link/basic.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,9 @@ def create_thunk_inputs(self, storage_map: Dict[Variable, List[Any]]) -> List[An
609609
def jit_compile(self, fn: Callable) -> Callable:
610610
"""JIT compile a converted ``FunctionGraph``."""
611611

612+
def typify(self, var: Variable):
613+
return var
614+
612615
def output_filter(self, var: Variable, out: Any) -> Any:
613616
"""Apply a filter to the data output by a JITed function call."""
614617
return out
@@ -735,7 +738,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
735738
return (
736739
fn,
737740
[
738-
Container(input, storage)
741+
Container(self.typify(input), storage)
739742
for input, storage in zip(fgraph.inputs, input_storage)
740743
],
741744
[

pytensor/link/jax/dispatch/random.py

+47-54
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytensor.tensor.random.basic as aer
1111
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
1212
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
13+
from pytensor.tensor.random.type import RandomType
1314
from pytensor.tensor.shape import Shape, Shape_i
1415

1516

@@ -55,8 +56,7 @@ def jax_typify_RandomState(state, **kwargs):
5556
state = state.get_state(legacy=False)
5657
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
5758
# XXX: Is this a reasonable approach?
58-
state["jax_state"] = state["state"]["key"][0:2]
59-
return state
59+
return state["state"]["key"][0:2]
6060

6161

6262
@jax_typify.register(Generator)
@@ -81,7 +81,27 @@ def jax_typify_Generator(rng, **kwargs):
8181
state_32 = _coerce_to_uint32_array(state["state"]["state"])
8282
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
8383
state["state"]["state"] = state_32[0] << 32 | state_32[1]
84-
return state
84+
return state["jax_state"]
85+
86+
87+
class RandomPRNGKeyType(RandomType[jax.random.PRNGKey]):
88+
def filter(self, data, strict: bool = False, allow_downcast=None):
89+
# PRNGs are just JAX Arrays, we assume this is a valid one!
90+
if isinstance(data, jax.Array):
91+
return data
92+
93+
if strict:
94+
raise TypeError()
95+
96+
return jax_typify(data)
97+
98+
99+
random_prng_key_type = RandomPRNGKeyType()
100+
101+
102+
@jax_typify.register(RandomType)
103+
def jax_typify_RandomType(type):
104+
return random_prng_key_type()
85105

86106

87107
@jax_funcify.register(aer.RandomVariable)
@@ -128,12 +148,10 @@ def jax_sample_fn_generic(op):
128148
name = op.name
129149
jax_op = getattr(jax.random, name)
130150

131-
def sample_fn(rng, size, dtype, *parameters):
132-
rng_key = rng["jax_state"]
151+
def sample_fn(rng_key, size, dtype, *parameters):
133152
rng_key, sampling_key = jax.random.split(rng_key, 2)
134153
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
135-
rng["jax_state"] = rng_key
136-
return (rng, sample)
154+
return (rng_key, sample)
137155

138156
return sample_fn
139157

@@ -155,13 +173,11 @@ def jax_sample_fn_loc_scale(op):
155173
name = op.name
156174
jax_op = getattr(jax.random, name)
157175

158-
def sample_fn(rng, size, dtype, *parameters):
159-
rng_key = rng["jax_state"]
176+
def sample_fn(rng_key, size, dtype, *parameters):
160177
rng_key, sampling_key = jax.random.split(rng_key, 2)
161178
loc, scale = parameters
162179
sample = loc + jax_op(sampling_key, size, dtype) * scale
163-
rng["jax_state"] = rng_key
164-
return (rng, sample)
180+
return (rng_key, sample)
165181

166182
return sample_fn
167183

@@ -173,12 +189,10 @@ def jax_sample_fn_no_dtype(op):
173189
name = op.name
174190
jax_op = getattr(jax.random, name)
175191

176-
def sample_fn(rng, size, dtype, *parameters):
177-
rng_key = rng["jax_state"]
192+
def sample_fn(rng_key, size, dtype, *parameters):
178193
rng_key, sampling_key = jax.random.split(rng_key, 2)
179194
sample = jax_op(sampling_key, *parameters, shape=size)
180-
rng["jax_state"] = rng_key
181-
return (rng, sample)
195+
return (rng_key, sample)
182196

183197
return sample_fn
184198

@@ -199,15 +213,13 @@ def jax_sample_fn_uniform(op):
199213
name = "randint"
200214
jax_op = getattr(jax.random, name)
201215

202-
def sample_fn(rng, size, dtype, *parameters):
203-
rng_key = rng["jax_state"]
216+
def sample_fn(rng_key, size, dtype, *parameters):
204217
rng_key, sampling_key = jax.random.split(rng_key, 2)
205218
minval, maxval = parameters
206219
sample = jax_op(
207220
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
208221
)
209-
rng["jax_state"] = rng_key
210-
return (rng, sample)
222+
return (rng_key, sample)
211223

212224
return sample_fn
213225

@@ -224,13 +236,11 @@ def jax_sample_fn_shape_rate(op):
224236
name = op.name
225237
jax_op = getattr(jax.random, name)
226238

227-
def sample_fn(rng, size, dtype, *parameters):
228-
rng_key = rng["jax_state"]
239+
def sample_fn(rng_key, size, dtype, *parameters):
229240
rng_key, sampling_key = jax.random.split(rng_key, 2)
230241
(shape, rate) = parameters
231242
sample = jax_op(sampling_key, shape, size, dtype) / rate
232-
rng["jax_state"] = rng_key
233-
return (rng, sample)
243+
return (rng_key, sample)
234244

235245
return sample_fn
236246

@@ -239,13 +249,11 @@ def sample_fn(rng, size, dtype, *parameters):
239249
def jax_sample_fn_exponential(op):
240250
"""JAX implementation of `ExponentialRV`."""
241251

242-
def sample_fn(rng, size, dtype, *parameters):
243-
rng_key = rng["jax_state"]
252+
def sample_fn(rng_key, size, dtype, *parameters):
244253
rng_key, sampling_key = jax.random.split(rng_key, 2)
245254
(scale,) = parameters
246255
sample = jax.random.exponential(sampling_key, size, dtype) * scale
247-
rng["jax_state"] = rng_key
248-
return (rng, sample)
256+
return (rng_key, sample)
249257

250258
return sample_fn
251259

@@ -254,17 +262,15 @@ def sample_fn(rng, size, dtype, *parameters):
254262
def jax_sample_fn_t(op):
255263
"""JAX implementation of `StudentTRV`."""
256264

257-
def sample_fn(rng, size, dtype, *parameters):
258-
rng_key = rng["jax_state"]
265+
def sample_fn(rng_key, size, dtype, *parameters):
259266
rng_key, sampling_key = jax.random.split(rng_key, 2)
260267
(
261268
df,
262269
loc,
263270
scale,
264271
) = parameters
265272
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
266-
rng["jax_state"] = rng_key
267-
return (rng, sample)
273+
return (rng_key, sample)
268274

269275
return sample_fn
270276

@@ -273,13 +279,11 @@ def sample_fn(rng, size, dtype, *parameters):
273279
def jax_funcify_choice(op):
274280
"""JAX implementation of `ChoiceRV`."""
275281

276-
def sample_fn(rng, size, dtype, *parameters):
277-
rng_key = rng["jax_state"]
282+
def sample_fn(rng_key, size, dtype, *parameters):
278283
rng_key, sampling_key = jax.random.split(rng_key, 2)
279284
(a, p, replace) = parameters
280285
smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
281-
rng["jax_state"] = rng_key
282-
return (rng, smpl_value)
286+
return (rng_key, smpl_value)
283287

284288
return sample_fn
285289

@@ -288,13 +292,11 @@ def sample_fn(rng, size, dtype, *parameters):
288292
def jax_sample_fn_permutation(op):
289293
"""JAX implementation of `PermutationRV`."""
290294

291-
def sample_fn(rng, size, dtype, *parameters):
292-
rng_key = rng["jax_state"]
295+
def sample_fn(rng_key, size, dtype, *parameters):
293296
rng_key, sampling_key = jax.random.split(rng_key, 2)
294297
(x,) = parameters
295298
sample = jax.random.permutation(sampling_key, x)
296-
rng["jax_state"] = rng_key
297-
return (rng, sample)
299+
return (rng_key, sample)
298300

299301
return sample_fn
300302

@@ -309,15 +311,12 @@ def jax_sample_fn_binomial(op):
309311

310312
from numpyro.distributions.util import binomial
311313

312-
def sample_fn(rng, size, dtype, n, p):
313-
rng_key = rng["jax_state"]
314+
def sample_fn(rng_key, size, dtype, n, p):
314315
rng_key, sampling_key = jax.random.split(rng_key, 2)
315316

316317
sample = binomial(key=sampling_key, n=n, p=p, shape=size)
317318

318-
rng["jax_state"] = rng_key
319-
320-
return (rng, sample)
319+
return (rng_key, sample)
321320

322321
return sample_fn
323322

@@ -332,15 +331,12 @@ def jax_sample_fn_multinomial(op):
332331

333332
from numpyro.distributions.util import multinomial
334333

335-
def sample_fn(rng, size, dtype, n, p):
336-
rng_key = rng["jax_state"]
334+
def sample_fn(rng_key, size, dtype, n, p):
337335
rng_key, sampling_key = jax.random.split(rng_key, 2)
338336

339337
sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
340338

341-
rng["jax_state"] = rng_key
342-
343-
return (rng, sample)
339+
return (rng_key, sample)
344340

345341
return sample_fn
346342

@@ -355,17 +351,14 @@ def jax_sample_fn_vonmises(op):
355351

356352
from numpyro.distributions.util import von_mises_centered
357353

358-
def sample_fn(rng, size, dtype, mu, kappa):
359-
rng_key = rng["jax_state"]
354+
def sample_fn(rng_key, size, dtype, mu, kappa):
360355
rng_key, sampling_key = jax.random.split(rng_key, 2)
361356

362357
sample = von_mises_centered(
363358
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
364359
)
365360
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
366361

367-
rng["jax_state"] = rng_key
368-
369-
return (rng, sample)
362+
return (rng_key, sample)
370363

371364
return sample_fn

pytensor/link/jax/linker.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numpy.random import Generator, RandomState
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
6-
from pytensor.graph.basic import Constant
6+
from pytensor.graph.basic import Constant, Variable
77
from pytensor.link.basic import JITLinker
88

99

@@ -63,6 +63,11 @@ def jit_compile(self, fn):
6363
]
6464
return jax.jit(fn, static_argnums=static_argnums)
6565

66+
def typify(self, var: Variable):
67+
from pytensor.link.jax.dispatch import jax_typify
68+
69+
return jax_typify(var.type)
70+
6671
def create_thunk_inputs(self, storage_map):
6772
from pytensor.link.jax.dispatch import jax_typify
6873

tests/link/jax/test_random.py

+28
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytensor.graph.basic import Constant
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.tensor.random.basic import RandomVariable
15+
from pytensor.tensor.random.type import random_generator_type
1516
from pytensor.tensor.random.utils import RandomStream
1617
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
1718

@@ -22,6 +23,33 @@
2223
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2324

2425

26+
def test_rng_io():
27+
rng = random_generator_type("rng")
28+
next_rng, x = aer.normal(rng=rng).owner.outputs
29+
fn = pytensor.function([rng], [next_rng, x], mode="JAX")
30+
31+
np_rng = np.random.default_rng(0)
32+
np_rst = np.random.RandomState(1)
33+
jx_rng = jax.random.PRNGKey(2)
34+
35+
# Inputs - RNG outputs
36+
assert isinstance(fn(np_rng)[0], jax.Array)
37+
assert isinstance(fn(np_rst)[0], jax.Array)
38+
assert isinstance(fn(jx_rng)[0], jax.Array)
39+
40+
# Inputs - Value outputs
41+
assert fn(np_rng)[1] == fn(np_rng)[1]
42+
assert fn(np_rst)[1] == fn(np_rst)[1]
43+
assert fn(jx_rng)[1] == fn(jx_rng)[1]
44+
assert fn(np_rng)[1] != fn(np_rst)[1]
45+
assert fn(np_rng)[1] != fn(jx_rng)[1]
46+
47+
# Chained Inputs - RNG / Value outputs
48+
assert fn(fn(np_rng)[0])[1] != fn(np_rng)[1]
49+
assert fn(fn(np_rst)[0])[1] != fn(np_rst)[1]
50+
assert fn(fn(jx_rng)[0])[1] != fn(jx_rng)[1]
51+
52+
2553
def test_random_RandomStream():
2654
"""Two successive calls of a compiled graph using `RandomStream` should
2755
return different values.

0 commit comments

Comments
 (0)