Skip to content

Commit 2823dfc

Browse files
committed
Faster python implementation of MvNormal
Also remove bad default values
1 parent 1ed3611 commit 2823dfc

File tree

3 files changed

+49
-52
lines changed

3 files changed

+49
-52
lines changed

pytensor/tensor/random/basic.py

Lines changed: 17 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import numpy as np
55
import scipy.stats as stats
6+
from numpy import broadcast_shapes as np_broadcast_shapes
7+
from numpy import einsum as np_einsum
8+
from numpy.linalg import cholesky as np_cholesky
69

710
import pytensor
811
from pytensor.tensor import get_vector_length, specify_shape
@@ -831,27 +834,6 @@ def __call__(self, mu, kappa, size=None, **kwargs):
831834
vonmises = VonMisesRV()
832835

833836

834-
def safe_multivariate_normal(mean, cov, size=None, rng=None):
835-
"""A shape consistent multivariate normal sampler.
836-
837-
What we mean by "shape consistent": SciPy will return scalars when the
838-
arguments are vectors with dimension of size 1. We require that the output
839-
be at least 1D, so that it's consistent with the underlying random
840-
variable.
841-
842-
"""
843-
res = np.atleast_1d(
844-
stats.multivariate_normal(mean=mean, cov=cov, allow_singular=True).rvs(
845-
size=size, random_state=rng
846-
)
847-
)
848-
849-
if size is not None:
850-
res = res.reshape([*size, -1])
851-
852-
return res
853-
854-
855837
class MvNormalRV(RandomVariable):
856838
r"""A multivariate normal random variable.
857839
@@ -904,25 +886,20 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
904886

905887
@classmethod
906888
def rng_fn(cls, rng, mean, cov, size):
907-
if mean.ndim > 1 or cov.ndim > 2:
908-
# Neither SciPy nor NumPy implement parameter broadcasting for
909-
# multivariate normals (or any other multivariate distributions),
910-
# so we need to implement that here
911-
912-
if size is None:
913-
mean, cov = broadcast_params([mean, cov], [1, 2])
914-
else:
915-
mean = np.broadcast_to(mean, size + mean.shape[-1:])
916-
cov = np.broadcast_to(cov, size + cov.shape[-2:])
917-
918-
res = np.empty(mean.shape)
919-
for idx in np.ndindex(mean.shape[:-1]):
920-
m = mean[idx]
921-
c = cov[idx]
922-
res[idx] = safe_multivariate_normal(m, c, rng=rng)
923-
return res
924-
else:
925-
return safe_multivariate_normal(mean, cov, size=size, rng=rng)
889+
if size is None:
890+
size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
891+
892+
chol = np_cholesky(cov)
893+
out = rng.normal(size=(*size, mean.shape[-1]))
894+
np_einsum(
895+
"...ij,...j->...i", # numpy doesn't have a batch matrix-vector product
896+
chol,
897+
out,
898+
out=out,
899+
optimize=False, # Nothing to optimize with two operands, skip costly setup
900+
)
901+
out += mean
902+
return out
926903

927904

928905
multivariate_normal = MvNormalRV()

tests/tensor/random/rewriting/test_basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,8 +778,10 @@ def rand_bool_mask(shape, rng=None):
778778
multivariate_normal,
779779
(
780780
np.array([200, 250], dtype=config.floatX),
781-
# Second covariance is invalid, to test it is not chosen
782-
np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
781+
# Second covariance is very large, to test it is not chosen
782+
np.dstack([np.eye(2), np.eye(2) * 1000, np.eye(2)]).T.astype(
783+
config.floatX
784+
)
783785
* 1e-6,
784786
),
785787
(3,),

tests/tensor/random/test_basic.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -521,13 +521,19 @@ def test_fn(shape, scale, **kwargs):
521521

522522

523523
def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
524-
if mean is None:
525-
mean = np.array([0.0], dtype=config.floatX)
526-
if cov is None:
527-
cov = np.array([[1.0]], dtype=config.floatX)
528-
if size is not None:
529-
size = tuple(size)
530-
return multivariate_normal.rng_fn(random_state, mean, cov, size)
524+
rng = random_state if random_state is not None else np.random.default_rng()
525+
526+
if size is None:
527+
size = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
528+
529+
mean = np.broadcast_to(mean, (*size, *mean.shape[-1:]))
530+
cov = np.broadcast_to(cov, (*size, *cov.shape[-2:]))
531+
532+
@np.vectorize(signature="(n),(n,n)->(n)")
533+
def vec_mvnormal(mean, cov):
534+
return rng.multivariate_normal(mean, cov, method="cholesky")
535+
536+
return vec_mvnormal(mean, cov)
531537

532538

533539
@pytest.mark.parametrize(
@@ -609,18 +615,30 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
609615
),
610616
],
611617
)
618+
@pytest.mark.skipif(
619+
config.floatX == "float32",
620+
reason="Draws are only strictly equal to numpy in float64",
621+
)
612622
def test_mvnormal_samples(mu, cov, size):
613623
compare_sample_values(
614624
multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn
615625
)
616626

617627

618-
def test_mvnormal_default_args():
619-
compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn)
628+
def test_mvnormal_no_default_args():
629+
with pytest.raises(
630+
TypeError, match="missing 2 required positional arguments: 'mean' and 'cov'"
631+
):
632+
multivariate_normal()
633+
620634

635+
def test_mvnormal_impl_catches_incompatible_size():
621636
with pytest.raises(ValueError, match="operands could not be broadcast together "):
622637
multivariate_normal.rng_fn(
623-
None, np.zeros((3, 2)), np.ones((3, 2, 2)), size=(4,)
638+
np.random.default_rng(),
639+
np.zeros((3, 2)),
640+
np.broadcast_to(np.eye(2), (3, 2, 2)),
641+
size=(4,),
624642
)
625643

626644

0 commit comments

Comments
 (0)