Skip to content

Commit 4ce0b07

Browse files
committed
Typify RNG input variables in JAX linker
1 parent 5c87d74 commit 4ce0b07

File tree

4 files changed

+110
-58
lines changed

4 files changed

+110
-58
lines changed

pytensor/link/basic.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,9 @@ def create_thunk_inputs(self, storage_map: Dict[Variable, List[Any]]) -> List[An
609609
def jit_compile(self, fn: Callable) -> Callable:
610610
"""JIT compile a converted ``FunctionGraph``."""
611611

612+
def typify(self, var: Variable):
613+
return var
614+
612615
def output_filter(self, var: Variable, out: Any) -> Any:
613616
"""Apply a filter to the data output by a JITed function call."""
614617
return out
@@ -735,7 +738,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
735738
return (
736739
fn,
737740
[
738-
Container(input, storage)
741+
Container(self.typify(input), storage)
739742
for input, storage in zip(fgraph.inputs, input_storage)
740743
],
741744
[

pytensor/link/jax/dispatch/random.py

+56-54
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytensor.tensor.random.basic as aer
1111
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
1212
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
13+
from pytensor.tensor.random.type import RandomType
1314
from pytensor.tensor.shape import Shape, Shape_i
1415

1516

@@ -57,8 +58,7 @@ def jax_typify_RandomState(state, **kwargs):
5758
state = state.get_state(legacy=False)
5859
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
5960
# XXX: Is this a reasonable approach?
60-
state["jax_state"] = state["state"]["key"][0:2]
61-
return state
61+
return state["state"]["key"][0:2]
6262

6363

6464
@jax_typify.register(Generator)
@@ -83,7 +83,36 @@ def jax_typify_Generator(rng, **kwargs):
8383
state_32 = _coerce_to_uint32_array(state["state"]["state"])
8484
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
8585
state["state"]["state"] = state_32[0] << 32 | state_32[1]
86-
return state
86+
return state["jax_state"]
87+
88+
89+
class RandomPRNGKeyType(RandomType[jax.random.PRNGKey]):
90+
"""JAX-compatible PRNGKey type.
91+
92+
This type is not exposed to users directly.
93+
94+
It is introduced by the JIT linker in place of any RandomType input
95+
variables used in the original function. Nodes in the function graph will
96+
still show the original types as inputs and outputs.
97+
"""
98+
99+
def filter(self, data, strict: bool = False, allow_downcast=None):
100+
# PRNGs are just JAX Arrays, we assume this is a valid one!
101+
if isinstance(data, jax.Array):
102+
return data
103+
104+
if strict:
105+
raise TypeError()
106+
107+
return jax_typify(data)
108+
109+
110+
random_prng_key_type = RandomPRNGKeyType()
111+
112+
113+
@jax_typify.register(RandomType)
114+
def jax_typify_RandomType(type):
115+
return random_prng_key_type()
87116

88117

89118
@jax_funcify.register(aer.RandomVariable)
@@ -130,12 +159,10 @@ def jax_sample_fn_generic(op):
130159
name = op.name
131160
jax_op = getattr(jax.random, name)
132161

133-
def sample_fn(rng, size, dtype, *parameters):
134-
rng_key = rng["jax_state"]
162+
def sample_fn(rng_key, size, dtype, *parameters):
135163
rng_key, sampling_key = jax.random.split(rng_key, 2)
136164
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
137-
rng["jax_state"] = rng_key
138-
return (rng, sample)
165+
return (rng_key, sample)
139166

140167
return sample_fn
141168

@@ -157,13 +184,11 @@ def jax_sample_fn_loc_scale(op):
157184
name = op.name
158185
jax_op = getattr(jax.random, name)
159186

160-
def sample_fn(rng, size, dtype, *parameters):
161-
rng_key = rng["jax_state"]
187+
def sample_fn(rng_key, size, dtype, *parameters):
162188
rng_key, sampling_key = jax.random.split(rng_key, 2)
163189
loc, scale = parameters
164190
sample = loc + jax_op(sampling_key, size, dtype) * scale
165-
rng["jax_state"] = rng_key
166-
return (rng, sample)
191+
return (rng_key, sample)
167192

168193
return sample_fn
169194

@@ -175,12 +200,10 @@ def jax_sample_fn_no_dtype(op):
175200
name = op.name
176201
jax_op = getattr(jax.random, name)
177202

178-
def sample_fn(rng, size, dtype, *parameters):
179-
rng_key = rng["jax_state"]
203+
def sample_fn(rng_key, size, dtype, *parameters):
180204
rng_key, sampling_key = jax.random.split(rng_key, 2)
181205
sample = jax_op(sampling_key, *parameters, shape=size)
182-
rng["jax_state"] = rng_key
183-
return (rng, sample)
206+
return (rng_key, sample)
184207

185208
return sample_fn
186209

@@ -201,15 +224,13 @@ def jax_sample_fn_uniform(op):
201224
name = "randint"
202225
jax_op = getattr(jax.random, name)
203226

204-
def sample_fn(rng, size, dtype, *parameters):
205-
rng_key = rng["jax_state"]
227+
def sample_fn(rng_key, size, dtype, *parameters):
206228
rng_key, sampling_key = jax.random.split(rng_key, 2)
207229
minval, maxval = parameters
208230
sample = jax_op(
209231
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
210232
)
211-
rng["jax_state"] = rng_key
212-
return (rng, sample)
233+
return (rng_key, sample)
213234

214235
return sample_fn
215236

@@ -226,13 +247,11 @@ def jax_sample_fn_shape_rate(op):
226247
name = op.name
227248
jax_op = getattr(jax.random, name)
228249

229-
def sample_fn(rng, size, dtype, *parameters):
230-
rng_key = rng["jax_state"]
250+
def sample_fn(rng_key, size, dtype, *parameters):
231251
rng_key, sampling_key = jax.random.split(rng_key, 2)
232252
(shape, rate) = parameters
233253
sample = jax_op(sampling_key, shape, size, dtype) / rate
234-
rng["jax_state"] = rng_key
235-
return (rng, sample)
254+
return (rng_key, sample)
236255

237256
return sample_fn
238257

@@ -241,13 +260,11 @@ def sample_fn(rng, size, dtype, *parameters):
241260
def jax_sample_fn_exponential(op):
242261
"""JAX implementation of `ExponentialRV`."""
243262

244-
def sample_fn(rng, size, dtype, *parameters):
245-
rng_key = rng["jax_state"]
263+
def sample_fn(rng_key, size, dtype, *parameters):
246264
rng_key, sampling_key = jax.random.split(rng_key, 2)
247265
(scale,) = parameters
248266
sample = jax.random.exponential(sampling_key, size, dtype) * scale
249-
rng["jax_state"] = rng_key
250-
return (rng, sample)
267+
return (rng_key, sample)
251268

252269
return sample_fn
253270

@@ -256,17 +273,15 @@ def sample_fn(rng, size, dtype, *parameters):
256273
def jax_sample_fn_t(op):
257274
"""JAX implementation of `StudentTRV`."""
258275

259-
def sample_fn(rng, size, dtype, *parameters):
260-
rng_key = rng["jax_state"]
276+
def sample_fn(rng_key, size, dtype, *parameters):
261277
rng_key, sampling_key = jax.random.split(rng_key, 2)
262278
(
263279
df,
264280
loc,
265281
scale,
266282
) = parameters
267283
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
268-
rng["jax_state"] = rng_key
269-
return (rng, sample)
284+
return (rng_key, sample)
270285

271286
return sample_fn
272287

@@ -275,13 +290,11 @@ def sample_fn(rng, size, dtype, *parameters):
275290
def jax_funcify_choice(op):
276291
"""JAX implementation of `ChoiceRV`."""
277292

278-
def sample_fn(rng, size, dtype, *parameters):
279-
rng_key = rng["jax_state"]
293+
def sample_fn(rng_key, size, dtype, *parameters):
280294
rng_key, sampling_key = jax.random.split(rng_key, 2)
281295
(a, p, replace) = parameters
282296
smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
283-
rng["jax_state"] = rng_key
284-
return (rng, smpl_value)
297+
return (rng_key, smpl_value)
285298

286299
return sample_fn
287300

@@ -290,13 +303,11 @@ def sample_fn(rng, size, dtype, *parameters):
290303
def jax_sample_fn_permutation(op):
291304
"""JAX implementation of `PermutationRV`."""
292305

293-
def sample_fn(rng, size, dtype, *parameters):
294-
rng_key = rng["jax_state"]
306+
def sample_fn(rng_key, size, dtype, *parameters):
295307
rng_key, sampling_key = jax.random.split(rng_key, 2)
296308
(x,) = parameters
297309
sample = jax.random.permutation(sampling_key, x)
298-
rng["jax_state"] = rng_key
299-
return (rng, sample)
310+
return (rng_key, sample)
300311

301312
return sample_fn
302313

@@ -311,15 +322,12 @@ def jax_sample_fn_binomial(op):
311322

312323
from numpyro.distributions.util import binomial
313324

314-
def sample_fn(rng, size, dtype, n, p):
315-
rng_key = rng["jax_state"]
325+
def sample_fn(rng_key, size, dtype, n, p):
316326
rng_key, sampling_key = jax.random.split(rng_key, 2)
317327

318328
sample = binomial(key=sampling_key, n=n, p=p, shape=size)
319329

320-
rng["jax_state"] = rng_key
321-
322-
return (rng, sample)
330+
return (rng_key, sample)
323331

324332
return sample_fn
325333

@@ -334,15 +342,12 @@ def jax_sample_fn_multinomial(op):
334342

335343
from numpyro.distributions.util import multinomial
336344

337-
def sample_fn(rng, size, dtype, n, p):
338-
rng_key = rng["jax_state"]
345+
def sample_fn(rng_key, size, dtype, n, p):
339346
rng_key, sampling_key = jax.random.split(rng_key, 2)
340347

341348
sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
342349

343-
rng["jax_state"] = rng_key
344-
345-
return (rng, sample)
350+
return (rng_key, sample)
346351

347352
return sample_fn
348353

@@ -357,17 +362,14 @@ def jax_sample_fn_vonmises(op):
357362

358363
from numpyro.distributions.util import von_mises_centered
359364

360-
def sample_fn(rng, size, dtype, mu, kappa):
361-
rng_key = rng["jax_state"]
365+
def sample_fn(rng_key, size, dtype, mu, kappa):
362366
rng_key, sampling_key = jax.random.split(rng_key, 2)
363367

364368
sample = von_mises_centered(
365369
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
366370
)
367371
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
368372

369-
rng["jax_state"] = rng_key
370-
371-
return (rng, sample)
373+
return (rng_key, sample)
372374

373375
return sample_fn

pytensor/link/jax/linker.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numpy.random import Generator, RandomState
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
6-
from pytensor.graph.basic import Constant
6+
from pytensor.graph.basic import Constant, Variable
77
from pytensor.link.basic import JITLinker
88

99

@@ -14,6 +14,15 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1414
from pytensor.link.jax.dispatch import jax_funcify
1515
from pytensor.tensor.random.type import RandomType
1616

17+
if any(
18+
isinstance(inp.type, RandomType) and not isinstance(inp, SharedVariable)
19+
for inp in fgraph.inputs
20+
):
21+
warnings.warn(
22+
"RandomTypes are implicitly converted to random PRNGKey arrays in JAX. "
23+
"Input values should be provided in this format to avoid a conversion overhead."
24+
)
25+
1726
shared_rng_inputs = [
1827
inp
1928
for inp in fgraph.inputs
@@ -70,6 +79,11 @@ def jit_compile(self, fn):
7079
]
7180
return jax.jit(fn, static_argnums=static_argnums)
7281

82+
def typify(self, var: Variable):
83+
from pytensor.link.jax.dispatch import jax_typify
84+
85+
return jax_typify(var.type)
86+
7387
def create_thunk_inputs(self, storage_map):
7488
from pytensor.link.jax.dispatch import jax_typify
7589

tests/link/jax/test_random.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import pytensor
66
import pytensor.tensor as at
77
import pytensor.tensor.random as aer
8-
from pytensor.compile.function import function
98
from pytensor.compile.sharedvalue import SharedVariable, shared
109
from pytensor.graph.basic import Constant
1110
from pytensor.graph.fg import FunctionGraph
1211
from pytensor.tensor.random.basic import RandomVariable
12+
from pytensor.tensor.random.type import random_generator_type, random_state_type
1313
from pytensor.tensor.random.utils import RandomStream
1414
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
1515

@@ -24,7 +24,40 @@ def random_function(*args, **kwargs):
2424
with pytest.warns(
2525
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
2626
):
27-
return function(*args, **kwargs)
27+
return pytensor.function(*args, **kwargs)
28+
29+
30+
@pytest.mark.parametrize("random_type", ("generator", "state"))
31+
def test_rng_io(random_type):
32+
"""Test explicit (non-shared) input and output RNG types in JAX."""
33+
if random_type == "generator":
34+
rng = random_generator_type("rng")
35+
np_rng = np.random.default_rng(0)
36+
else:
37+
rng = random_state_type("rng")
38+
np_rng = np.random.RandomState(0)
39+
jx_rng = jax.random.PRNGKey(0)
40+
41+
next_rng, x = aer.normal(rng=rng).owner.outputs
42+
43+
with pytest.warns(
44+
UserWarning,
45+
match="RandomTypes are implicitly converted to random PRNGKey arrays",
46+
):
47+
fn = pytensor.function([rng], [next_rng, x], mode="JAX")
48+
49+
# Inputs - RNG outputs
50+
assert isinstance(fn(np_rng)[0], jax.Array)
51+
assert isinstance(fn(jx_rng)[0], jax.Array)
52+
53+
# Inputs - Value outputs
54+
assert fn(np_rng)[1] == fn(np_rng)[1]
55+
assert fn(jx_rng)[1] == fn(jx_rng)[1]
56+
assert fn(np_rng)[1] != fn(jx_rng)[1]
57+
58+
# Chained Inputs - RNG / Value outputs
59+
assert fn(fn(np_rng)[0])[1] != fn(np_rng)[1]
60+
assert fn(fn(jx_rng)[0])[1] != fn(jx_rng)[1]
2861

2962

3063
def test_random_RandomStream():

0 commit comments

Comments
 (0)