
Simplify dispatch of JAX random variables by handling rng split automatically #1204

Closed
@ricardoV94

Description


The innermost JAX dispatch for `RandomVariable`s, `jax_sample_fn`, currently looks like this:

```python
import jax

import pytensor.tensor.random.basic as ptr

# `jax_sample_fn` is the singledispatch function defined alongside this code.


@jax_sample_fn.register(ptr.CauchyRV)
@jax_sample_fn.register(ptr.GumbelRV)
@jax_sample_fn.register(ptr.LaplaceRV)
@jax_sample_fn.register(ptr.LogisticRV)
@jax_sample_fn.register(ptr.NormalRV)
def jax_sample_fn_loc_scale(op, node):
    """JAX implementation of random variables in the loc-scale families.

    JAX only implements the standard version of random variables in the
    loc-scale family. We thus need to translate and rescale the results
    manually.
    """
    name = op.name
    jax_op = getattr(jax.random, name)

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        # Every implementation repeats this split-and-store boilerplate.
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        loc, scale = parameters
        if size is None:
            size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
        sample = loc + jax_op(sampling_key, size, dtype) * scale
        rng["jax_state"] = rng_key
        return (rng, sample)

    return sample_fn
```

The whole RNG logic (splitting the key and storing the new state) could instead be handled once, in the outermost dispatch `jax_funcify_RandomVariable`:

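```python
# Excerpt from the body of `jax_funcify_RandomVariable`
if None in static_size:
    # Some entries of `size` are not known statically; make sure JAX can still use it.
    assert_size_argument_jax_compatible(node)

    def sample_fn(rng, size, *parameters):
        return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)

else:
    # `size` is fully known at compile time, so pass the static value instead.
    def sample_fn(rng, size, *parameters):
        return jax_sample_fn(op, node=node)(
            rng, static_size, out_dtype, *parameters
        )

return sample_fn
```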
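A minimal pure-JAX sketch of the idea, outside PyTensor. The names `make_sample_fn` and `loc_scale_normal` are hypothetical, the `rng` dict with a `"jax_state"` entry mirrors the existing code, and a real version would still need the static/dynamic `size` handling of `jax_funcify_RandomVariable`. The outer wrapper performs the single `jax.random.split` and stores the advanced key; the inner sampler only turns a ready-to-use key into a sample.

```python
import jax
import jax.numpy as jnp


def make_sample_fn(inner_sample_fn):
    # Hypothetical outer wrapper: the only place that touches RNG state.
    def sample_fn(rng, size, dtype, *parameters):
        rng_key, sampling_key = jax.random.split(rng["jax_state"], 2)
        sample = inner_sample_fn(sampling_key, size, dtype, *parameters)
        # Advance the stored key, as the current per-implementation code does.
        rng["jax_state"] = rng_key
        return rng, sample

    return sample_fn


def loc_scale_normal(sampling_key, size, dtype, loc, scale):
    # Hypothetical inner sampler: no split, no RNG state, just the sample.
    if size is None:
        size = jnp.broadcast_arrays(loc, scale)[0].shape
    return loc + jax.random.normal(sampling_key, size, dtype) * scale


normal_sample_fn = make_sample_fn(loc_scale_normal)
initial_key = jax.random.PRNGKey(0)
rng = {"jax_state": initial_key}
rng, draws = normal_sample_fn(rng, (1000,), jnp.float32, 3.0, 2.0)
```

Since the inner function no longer returns RNG state, each registration would shrink to its sampling formula.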

If an implementation needs more than one sampling key, it can split the key it receives again.
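For example, a hypothetical inner sampler that needs two independent draws can split the single key it is handed. The Beta-from-two-Gammas construction here only illustrates that pattern (`jax.random.beta` already exists), and `beta_via_gammas` is not a PyTensor function.

```python
import jax
import jax.numpy as jnp


def beta_via_gammas(sampling_key, size, dtype, a, b):
    # Hypothetical inner sampler: it receives one key and splits it itself,
    # because it needs two independent Gamma draws.
    key_a, key_b = jax.random.split(sampling_key, 2)
    if size is None:
        size = jnp.broadcast_arrays(a, b)[0].shape
    x = jax.random.gamma(key_a, a, size, dtype)
    y = jax.random.gamma(key_b, b, size, dtype)
    # If X ~ Gamma(a, 1) and Y ~ Gamma(b, 1) are independent, X / (X + Y) ~ Beta(a, b).
    return x / (x + y)


draws = beta_via_gammas(jax.random.PRNGKey(1), (1000,), jnp.float32, 2.0, 5.0)
```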
