Skip to content

Commit 1c290ef

Browse files
authored
Fix MvNormal.random (#4207)
* Fixed MvNormal.random method * Added some comments * Handled corner case when self.shape is not provided * Fixed comments * Considered the corner case of 'point' as well. * Fixed MvNormal.random method * Handled sample and batch dimensions in tau parametrization using numpy Added batch dimensions to all parametrization * Modified logic while inserting batch dimensions to parametrization * Used shapes_utils.broadcast_dist_samples_to function for broadcasting * Make pylint pass * Make test passes 🤞 hopefully * Modified logic and added tests * Given a mention in release notes
1 parent 4cccb46 commit 1c290ef

File tree

4 files changed

+144
-53
lines changed

4 files changed

+144
-53
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
4545
- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
4646
- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116))
4747
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))
48+
- Refactored MvNormal.random method with better handling of sample, batch and event shapes. [#4207](https://github.com/pymc-devs/pymc3/pull/4207)
4849

4950
### Documentation
5051
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).

pymc3/distributions/multivariate.py

Lines changed: 31 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .continuous import ChiSquared, Normal
3737
from .special import gammaln, multigammaln
3838
from .dist_math import bound, logpow, factln
39-
from .shape_utils import to_tuple
39+
from .shape_utils import to_tuple, broadcast_dist_samples_to
4040
from ..math import kron_dot, kron_diag, kron_solve_lower, kronecker
4141

4242

@@ -250,58 +250,36 @@ def random(self, point=None, size=None):
250250
-------
251251
array
252252
"""
253-
if size is None:
254-
size = tuple()
255-
else:
256-
if not isinstance(size, tuple):
257-
try:
258-
size = tuple(size)
259-
except TypeError:
260-
size = (size,)
253+
size = to_tuple(size)
261254

262-
if self._cov_type == "cov":
263-
mu, cov = draw_values([self.mu, self.cov], point=point, size=size)
264-
if mu.shape[-1] != cov.shape[-1]:
265-
raise ValueError("Shapes for mu and cov don't match")
255+
param_attribute = getattr(self, "chol_cov" if self._cov_type == "chol" else self._cov_type)
256+
mu, param = draw_values([self.mu, param_attribute], point=point, size=size)
266257

267-
try:
268-
dist = stats.multivariate_normal(mean=mu, cov=cov, allow_singular=True)
269-
except ValueError:
270-
size += (mu.shape[-1],)
271-
return np.nan * np.zeros(size)
272-
return dist.rvs(size)
273-
elif self._cov_type == "chol":
274-
mu, chol = draw_values([self.mu, self.chol_cov], point=point, size=size)
275-
if size and mu.ndim == len(size) and mu.shape == size:
276-
mu = mu[..., np.newaxis]
277-
if mu.shape[-1] != chol.shape[-1] and mu.shape[-1] != 1:
278-
raise ValueError("Shapes for mu and chol don't match")
279-
broadcast_shape = np.broadcast(np.empty(mu.shape[:-1]), np.empty(chol.shape[:-2])).shape
280-
281-
mu = np.broadcast_to(mu, broadcast_shape + (chol.shape[-1],))
282-
chol = np.broadcast_to(chol, broadcast_shape + chol.shape[-2:])
283-
# If mu and chol were fixed by the point, only the standard normal
284-
# should change
285-
if mu.shape[: len(size)] != size:
286-
std_norm_shape = size + mu.shape
287-
else:
288-
std_norm_shape = mu.shape
289-
standard_normal = np.random.standard_normal(std_norm_shape)
290-
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)
291-
else:
292-
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
293-
if mu.shape[-1] != tau[0].shape[-1]:
294-
raise ValueError("Shapes for mu and tau don't match")
258+
dist_shape = to_tuple(self.shape)
259+
output_shape = size + dist_shape
295260

296-
size += (mu.shape[-1],)
297-
try:
298-
chol = linalg.cholesky(tau, lower=True)
299-
except linalg.LinAlgError:
300-
return np.nan * np.zeros(size)
261+
# Simple, there can be only be 1 batch dimension, only available from `mu`.
262+
# Insert it into `param` before events, if there is a sample shape in front.
263+
if param.ndim > 2 and dist_shape[:-1]:
264+
param = param.reshape(size + (1,) + param.shape[-2:])
265+
266+
mu = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)[0]
267+
param = np.broadcast_to(param, shape=output_shape + dist_shape[-1:])
268+
269+
assert mu.shape == output_shape
270+
assert param.shape == output_shape + dist_shape[-1:]
271+
272+
if self._cov_type == "cov":
273+
chol = np.linalg.cholesky(param)
274+
elif self._cov_type == "chol":
275+
chol = param
276+
else: # tau -> chol -> swapaxes (chol, -1, -2) -> inv ...
277+
lower_chol = np.linalg.cholesky(param)
278+
upper_chol = np.swapaxes(lower_chol, -1, -2)
279+
chol = np.linalg.inv(upper_chol)
301280

302-
standard_normal = np.random.standard_normal(size)
303-
transformed = linalg.solve_triangular(chol, standard_normal.T, lower=True)
304-
return mu + transformed.T
281+
standard_normal = np.random.standard_normal(output_shape)
282+
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)
305283

306284
def logp(self, value):
307285
"""
@@ -399,13 +377,13 @@ def random(self, point=None, size=None):
399377
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
400378
if self._cov_type == "cov":
401379
(cov,) = draw_values([self.cov], point=point, size=size)
402-
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
380+
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
403381
elif self._cov_type == "tau":
404382
(tau,) = draw_values([self.tau], point=point, size=size)
405-
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
383+
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
406384
else:
407385
(chol,) = draw_values([self.chol_cov], point=point, size=size)
408-
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
386+
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)
409387

410388
samples = dist.random(point, size)
411389

@@ -1915,6 +1893,7 @@ def random(self, point=None, size=None):
19151893
"""
19161894
# Expand params into terms MvNormal can understand to force consistency
19171895
self._setup_random()
1896+
self.mv_params["shape"] = self.shape
19181897
dist = MvNormal.dist(**self.mv_params)
19191898
return dist.random(point, size)
19201899

pymc3/tests/test_distributions_random.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import itertools
1516
import pytest
1617
import numpy as np
1718
import numpy.testing as npt
@@ -27,6 +28,7 @@
2728
draw_values,
2829
_DrawValuesContext,
2930
_DrawValuesContextBlocker,
31+
to_tuple,
3032
)
3133
from .helpers import SeededTest
3234
from .test_distributions import (
@@ -1544,3 +1546,112 @@ def test_Triangular(
15441546
prior_samples=prior_samples,
15451547
)
15461548
assert prior["target"].shape == (prior_samples,) + shape
1549+
1550+
1551+
def generate_shapes(include_params=False, xfail=False):
1552+
# fmt: off
1553+
mudim_as_event = [
1554+
[None, 1, 3, 10, (10, 3), 100],
1555+
[(3,)],
1556+
[(1,), (3,)],
1557+
["cov", "chol", "tau"]
1558+
]
1559+
# fmt: on
1560+
mudim_as_dist = [
1561+
[None, 1, 3, 10, (10, 3), 100],
1562+
[(10, 3)],
1563+
[(1,), (3,), (1, 1), (1, 3), (10, 1), (10, 3)],
1564+
["cov", "chol", "tau"],
1565+
]
1566+
if not include_params:
1567+
del mudim_as_event[-1]
1568+
del mudim_as_dist[-1]
1569+
data = itertools.chain(itertools.product(*mudim_as_event), itertools.product(*mudim_as_dist))
1570+
if xfail:
1571+
data = list(data)
1572+
for index in range(len(data)):
1573+
if data[index][0] in (None, 1):
1574+
data[index] = pytest.param(
1575+
*data[index], marks=pytest.mark.xfail(reason="wait for PR #4214")
1576+
)
1577+
return data
1578+
1579+
1580+
class TestMvNormal(SeededTest):
1581+
@pytest.mark.parametrize(
1582+
["sample_shape", "dist_shape", "mu_shape", "param"],
1583+
generate_shapes(include_params=True, xfail=False),
1584+
ids=str,
1585+
)
1586+
def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param):
1587+
dist = pm.MvNormal.dist(mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape)
1588+
output_shape = to_tuple(sample_shape) + dist_shape
1589+
assert dist.random(size=sample_shape).shape == output_shape
1590+
1591+
@pytest.mark.parametrize(
1592+
["sample_shape", "dist_shape", "mu_shape"],
1593+
generate_shapes(include_params=False, xfail=True),
1594+
ids=str,
1595+
)
1596+
def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
1597+
with pm.Model() as model:
1598+
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
1599+
sd_dist = pm.Exponential.dist(1.0, shape=3)
1600+
chol, corr, stds = pm.LKJCholeskyCov(
1601+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
1602+
)
1603+
mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape)
1604+
prior = pm.sample_prior_predictive(samples=sample_shape)
1605+
1606+
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape
1607+
1608+
@pytest.mark.parametrize(
1609+
["sample_shape", "dist_shape", "mu_shape"],
1610+
generate_shapes(include_params=False, xfail=True),
1611+
ids=str,
1612+
)
1613+
def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
1614+
with pm.Model() as model:
1615+
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
1616+
sd_dist = pm.Exponential.dist(1.0, shape=3)
1617+
chol, corr, stds = pm.LKJCholeskyCov(
1618+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
1619+
)
1620+
mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
1621+
prior = pm.sample_prior_predictive(samples=sample_shape)
1622+
1623+
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape
1624+
1625+
def test_issue_3758(self):
1626+
np.random.seed(42)
1627+
ndim = 50
1628+
with pm.Model() as model:
1629+
a = pm.Normal("a", sigma=100, shape=ndim)
1630+
b = pm.Normal("b", mu=a, sigma=1, shape=ndim)
1631+
c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim)
1632+
d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim)
1633+
samples = pm.sample_prior_predictive(1000)
1634+
1635+
for var in "abcd":
1636+
assert not np.isnan(np.std(samples[var]))
1637+
1638+
def test_issue_3829(self):
1639+
with pm.Model() as model:
1640+
x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5))
1641+
trace_pp = pm.sample_prior_predictive(50)
1642+
1643+
assert np.shape(trace_pp["x"][0]) == (2, 5)
1644+
1645+
def test_issue_3706(self):
1646+
N = 10
1647+
Sigma = np.eye(2)
1648+
1649+
with pm.Model() as model:
1650+
1651+
X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2))
1652+
betas = pm.Normal("betas", 0, 1, shape=2)
1653+
y = pm.Deterministic("y", pm.math.dot(X, betas))
1654+
1655+
prior_pred = pm.sample_prior_predictive(1)
1656+
1657+
assert prior_pred["X"].shape == (1, N, 2)

pymc3/tests/test_mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def build_toy_dataset(N, K):
368368
)
369369
)
370370
chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True))
371-
comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i]))
371+
comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i], shape=D))
372372

373373
pm.Mixture("x_obs", pi, comp_dist, observed=X)
374374
with model:

0 commit comments

Comments
 (0)