Closed
Description
The JAX innermost dispatch functions for RandomVariables, jax_sample_fn, look like this:
pytensor/pytensor/link/jax/dispatch/random.py
Lines 146 to 172 in 964cccb
The whole rng logic could be handled in the outermost dispatch, jax_funcify_RandomVariable, instead:
pytensor/pytensor/link/jax/dispatch/random.py
Lines 104 to 117 in 964cccb
If an implementation needs a split other than 2, it can split the provided rng key again anyway.
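A minimal sketch of the idea, not the actual pytensor code: it assumes the rng is passed around as a dict holding the key under "jax_state", and wrap_with_rng_handling, normal_sample_fn and sum_of_three_uniforms_sample_fn are hypothetical names used only for illustration. The outer wrapper does the single split and the state update; the inner samplers only ever see a ready-to-use key.

```python
import jax
import jax.numpy as jnp


def normal_sample_fn(sampling_key, size, dtype, loc, scale):
    # Hypothetical inner sampler: no rng bookkeeping, just draws with the given key.
    return loc + scale * jax.random.normal(sampling_key, shape=size, dtype=dtype)


def sum_of_three_uniforms_sample_fn(sampling_key, size, dtype):
    # Hypothetical inner sampler that needs three keys: it splits the key it was given.
    key_a, key_b, key_c = jax.random.split(sampling_key, 3)
    return (
        jax.random.uniform(key_a, shape=size, dtype=dtype)
        + jax.random.uniform(key_b, shape=size, dtype=dtype)
        + jax.random.uniform(key_c, shape=size, dtype=dtype)
    )


def wrap_with_rng_handling(inner_sample_fn):
    # Hypothetical stand-in for what the outermost dispatch could do.
    def sample_fn(rng, size, dtype, *parameters):
        # Assumption: the rng is a dict with the JAX key stored under "jax_state".
        rng_key = rng["jax_state"]
        # One split here: one key is carried forward, one is handed to the sampler.
        next_key, sampling_key = jax.random.split(rng_key, 2)
        sample = inner_sample_fn(sampling_key, size, dtype, *parameters)
        new_rng = {**rng, "jax_state": next_key}
        return new_rng, sample

    return sample_fn


rng = {"jax_state": jax.random.PRNGKey(0)}

normal_fn = wrap_with_rng_handling(normal_sample_fn)
rng1, draw1 = normal_fn(rng, (3,), jnp.float32, 0.0, 1.0)
rng2, draw2 = normal_fn(rng1, (3,), jnp.float32, 0.0, 1.0)

uniform_sum_fn = wrap_with_rng_handling(sum_of_three_uniforms_sample_fn)
rng3, draw3 = uniform_sum_fn(rng2, (2, 2), jnp.float32)
```

With this layout, each inner sampler reduces to a function of (key, size, dtype, parameters), and the split-and-update pattern lives in one place instead of being repeated in every implementation.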