Skip to content

Commit 51210c3

Browse files
committed
Extend supported RandomVariables in JAX backend via NumPyro
Dependency is optional
1 parent dcd24a3 commit 51210c3

File tree

5 files changed

+206
-8
lines changed

5 files changed

+206
-8
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ jobs:
117117
run: |
118118
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
119119
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
120-
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
120+
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
121121
pip install -e ./
122122
mamba list && pip freeze
123123
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

pytensor/link/jax/dispatch/extra_ops.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import jax.numpy as jnp
44

55
from pytensor.link.jax.dispatch.basic import jax_funcify
6+
from pytensor.tensor.basic import infer_static_shape
67
from pytensor.tensor.extra_ops import (
78
Bartlett,
89
BroadcastTo,
@@ -102,8 +103,12 @@ def ravelmultiindex(*inp, mode=mode, order=order):
102103

103104

104105
@jax_funcify.register(BroadcastTo)
105-
def jax_funcify_BroadcastTo(op, **kwargs):
106+
def jax_funcify_BroadcastTo(op, node, **kwargs):
107+
shape = node.inputs[1:]
108+
static_shape = infer_static_shape(shape)[1]
109+
106110
def broadcast_to(x, *shape):
111+
shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
107112
return jnp.broadcast_to(x, shape)
108113

109114
return broadcast_to

pytensor/link/jax/dispatch/random.py

+82-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from functools import singledispatch
22

33
import jax
4+
import numpy as np
45
from numpy.random import Generator, RandomState
56
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
67
_coerce_to_uint32_array,
@@ -12,6 +13,13 @@
1213
from pytensor.tensor.shape import Shape, Shape_i
1314

1415

16+
try:
17+
import numpyro # noqa: F401
18+
19+
numpyro_available = True
20+
except ImportError:
21+
numpyro_available = False
22+
1523
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
1624

1725

@@ -83,11 +91,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
8391
out_dtype = rv.type.dtype
8492
out_size = rv.type.shape
8593

86-
if isinstance(op, aer.MvNormalRV):
87-
# PyTensor sets the `size` to the concatenation of the support shape
88-
# and the batch shape, while JAX explicitly requires the batch
89-
# shape only for the multivariate normal.
90-
out_size = node.outputs[1].type.shape[:-1]
94+
if op.ndim_supp > 0:
95+
out_size = node.outputs[1].type.shape[: -op.ndim_supp]
9196

9297
# If one dimension has unknown size, either the size is determined
9398
# by a `Shape` operator in which case JAX will compile, or it is
@@ -292,3 +297,75 @@ def sample_fn(rng, size, dtype, *parameters):
292297
return (rng, sample)
293298

294299
return sample_fn
300+
301+
302+
@jax_sample_fn.register(aer.BinomialRV)
303+
def jax_sample_fn_binomial(op):
304+
if not numpyro_available:
305+
raise NotImplementedError(
306+
f"No JAX implementation for the given distribution: {op.name}. "
307+
"Implementation is available if NumPyro is installed."
308+
)
309+
310+
from numpyro.distributions.util import binomial
311+
312+
def sample_fn(rng, size, dtype, n, p):
313+
rng_key = rng["jax_state"]
314+
rng_key, sampling_key = jax.random.split(rng_key, 2)
315+
316+
sample = binomial(key=sampling_key, n=n, p=p, shape=size)
317+
318+
rng["jax_state"] = rng_key
319+
320+
return (rng, sample)
321+
322+
return sample_fn
323+
324+
325+
@jax_sample_fn.register(aer.MultinomialRV)
326+
def jax_sample_fn_multinomial(op):
327+
if not numpyro_available:
328+
raise NotImplementedError(
329+
f"No JAX implementation for the given distribution: {op.name}. "
330+
"Implementation is available if NumPyro is installed."
331+
)
332+
333+
from numpyro.distributions.util import multinomial
334+
335+
def sample_fn(rng, size, dtype, n, p):
336+
rng_key = rng["jax_state"]
337+
rng_key, sampling_key = jax.random.split(rng_key, 2)
338+
339+
sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
340+
341+
rng["jax_state"] = rng_key
342+
343+
return (rng, sample)
344+
345+
return sample_fn
346+
347+
348+
@jax_sample_fn.register(aer.VonMisesRV)
349+
def jax_sample_fn_vonmises(op):
350+
if not numpyro_available:
351+
raise NotImplementedError(
352+
f"No JAX implementation for the given distribution: {op.name}. "
353+
"Implementation is available if NumPyro is installed."
354+
)
355+
356+
from numpyro.distributions.util import von_mises_centered
357+
358+
def sample_fn(rng, size, dtype, mu, kappa):
359+
rng_key = rng["jax_state"]
360+
rng_key, sampling_key = jax.random.split(rng_key, 2)
361+
362+
sample = von_mises_centered(
363+
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
364+
)
365+
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
366+
367+
rng["jax_state"] = rng_key
368+
369+
return (rng, sample)
370+
371+
return sample_fn

pytensor/tensor/random/rewriting/jax.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from pytensor.graph.rewriting.basic import in2out, node_rewriter
33
from pytensor.graph.rewriting.db import SequenceDB
44
from pytensor.tensor import abs as abs_t
5-
from pytensor.tensor import exp, floor, log, log1p, reciprocal, sqrt
5+
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
66
from pytensor.tensor.basic import MakeVector, cast, ones_like, switch, zeros_like
77
from pytensor.tensor.elemwise import DimShuffle
88
from pytensor.tensor.random.basic import (
9+
BetaBinomialRV,
910
ChiSquareRV,
1011
GenGammaRV,
1112
GeometricRV,
@@ -14,6 +15,8 @@
1415
LogNormalRV,
1516
NegBinomialRV,
1617
WaldRV,
18+
beta,
19+
binomial,
1720
gamma,
1821
normal,
1922
poisson,
@@ -133,6 +136,15 @@ def wald_from_normal_uniform(fgraph, node):
133136
return [next_rng, cast(w, dtype=node.default_output().dtype)]
134137

135138

139+
@node_rewriter([BetaBinomialRV])
140+
def beta_binomial_from_beta_binomial(fgraph, node):
141+
rng, *other_inputs, n, a, b = node.inputs
142+
n, a, b = broadcast_arrays(n, a, b)
143+
next_rng, b = beta.make_node(rng, *other_inputs, a, b).outputs
144+
next_rng, b = binomial.make_node(next_rng, *other_inputs, n, b).outputs
145+
return [next_rng, b]
146+
147+
136148
random_vars_opt = SequenceDB()
137149
random_vars_opt.register(
138150
"lognormal_from_normal",
@@ -174,6 +186,11 @@ def wald_from_normal_uniform(fgraph, node):
174186
in2out(wald_from_normal_uniform),
175187
"jax",
176188
)
189+
random_vars_opt.register(
190+
"beta_binomial_from_beta_binomial",
191+
in2out(beta_binomial_from_beta_binomial),
192+
"jax",
193+
)
177194
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
178195

179196
optdb.register(

tests/link/jax/test_random.py

+99
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
jax = pytest.importorskip("jax")
2020

2121

22+
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
23+
24+
2225
def test_random_RandomStream():
2326
"""Two successive calls of a compiled graph using `RandomStream` should
2427
return different values.
@@ -377,6 +380,25 @@ def test_random_updates(rng_ctor):
377380
# https://stackoverflow.com/a/48603469
378381
lambda mean, scale: (mean / scale, 0, scale),
379382
),
383+
pytest.param(
384+
aer.vonmises,
385+
[
386+
set_test_value(
387+
at.dvector(),
388+
np.array([-0.5, 1.3], dtype=np.float64),
389+
),
390+
set_test_value(
391+
at.dvector(),
392+
np.array([5.5, 13.0], dtype=np.float64),
393+
),
394+
],
395+
(2,),
396+
"vonmises",
397+
lambda mu, kappa: (kappa, mu),
398+
marks=pytest.mark.skipif(
399+
not numpyro_available, reason="VonMises dispatch requires numpyro"
400+
),
401+
),
380402
],
381403
)
382404
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
@@ -519,6 +541,83 @@ def test_negative_binomial():
519541
)
520542

521543

544+
@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro")
545+
def test_binomial():
546+
rng = shared(np.random.RandomState(123))
547+
n = np.array([10, 40])
548+
p = np.array([0.3, 0.7])
549+
g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
550+
g_fn = function([], g, mode=jax_mode)
551+
samples = g_fn()
552+
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
553+
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
554+
555+
556+
@pytest.mark.skipif(
557+
not numpyro_available, reason="BetaBinomial dispatch requires numpyro"
558+
)
559+
def test_beta_binomial():
560+
rng = shared(np.random.RandomState(123))
561+
n = np.array([10, 40])
562+
a = np.array([1.5, 13])
563+
b = np.array([0.5, 9])
564+
g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
565+
g_fn = function([], g, mode=jax_mode)
566+
samples = g_fn()
567+
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
568+
np.testing.assert_allclose(
569+
samples.std(axis=0),
570+
np.sqrt((n * a * b * (a + b + n)) / ((a + b) ** 2 * (a + b + 1))),
571+
rtol=0.1,
572+
)
573+
574+
575+
@pytest.mark.skipif(
576+
not numpyro_available, reason="Multinomial dispatch requires numpyro"
577+
)
578+
def test_multinomial():
579+
rng = shared(np.random.RandomState(123))
580+
n = np.array([10, 40])
581+
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
582+
g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
583+
g_fn = function([], g, mode=jax_mode)
584+
samples = g_fn()
585+
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
586+
np.testing.assert_allclose(
587+
samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1
588+
)
589+
590+
591+
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
592+
def test_vonmises_mu_outside_circle():
593+
# Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle
594+
# We test that the random draws from the JAX dispatch work as expected in these cases
595+
rng = shared(np.random.RandomState(123))
596+
mu = np.array([-30, 40])
597+
kappa = np.array([100, 10])
598+
g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
599+
g_fn = function([], g, mode=jax_mode)
600+
samples = g_fn()
601+
np.testing.assert_allclose(
602+
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
603+
)
604+
605+
# Circvar only does the correct thing in more recent versions of Scipy
606+
# https://github.com/scipy/scipy/pull/5747
607+
# np.testing.assert_allclose(
608+
# stats.circvar(samples, axis=0),
609+
# 1 - special.iv(1, kappa) / special.iv(0, kappa),
610+
# rtol=0.1,
611+
# )
612+
613+
# For now simple compare with std from numpy draws
614+
rng = np.random.default_rng(123)
615+
ref_samples = rng.vonmises(mu, kappa, size=(10_000, 2))
616+
np.testing.assert_allclose(
617+
np.std(samples, axis=0), np.std(ref_samples, axis=0), rtol=0.1
618+
)
619+
620+
522621
def test_random_unimplemented():
523622
"""Compiling a graph with a non-supported `RandomVariable` should
524623
raise an error.

0 commit comments

Comments
 (0)