Skip to content

Commit 398eae6

Browse files
authored
Fix MvStudentT.random (#4359)
* Fix MvStudentT.random * Given a mention in release notes
1 parent 0e9b9a4 commit 398eae6

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ This is the first release to support Python3.9 and to drop Python3.6.
1414
- In `sample_posterior_predictive` the `vars` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc3/pull/4343)).
1515
- The notebook gallery has been moved to https://github.com/pymc-devs/pymc-examples (see [#4348](https://github.com/pymc-devs/pymc3/pull/4348)).
1616
- `math.logsumexp` now matches `scipy.special.logsumexp` when arrays contain infinite values (see [#4360](https://github.com/pymc-devs/pymc3/pull/4360)).
17+
- Fixed mathematical formulation in `MvStudentT` random method. (see [#4359](https://github.com/pymc-devs/pymc3/pull/4359))
1718

1819
## PyMC3 3.10.0 (7 December 2020)
1920

pymc3/distributions/multivariate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,10 @@ class MvStudentT(_QuadFormBase):
324324
1+\frac{1}{\nu}
325325
({\mathbf x}-{\mu})^T
326326
{\Sigma}^{-1}({\mathbf x}-{\mu})
327-
\right]^{(\nu+p)/2}}
327+
\right]^{-(\nu+p)/2}}
328328
329329
======== =============================================
330-
Support :math:`x \in \mathbb{R}^k`
330+
Support :math:`x \in \mathbb{R}^p`
331331
Mean :math:`\mu` if :math:`\nu > 1` else undefined
332332
Variance :math:`\frac{\nu}{\mu-2}\Sigma`
333333
if :math:`\nu>2` else undefined
@@ -393,8 +393,10 @@ def random(self, point=None, size=None):
393393

394394
samples = dist.random(point, size)
395395

396-
chi2 = np.random.chisquare
397-
return (np.sqrt(nu) * samples.T / chi2(nu, size)).T + mu
396+
chi2_samples = np.random.chisquare(nu, size)
397+
# Add distribution shape to chi2 samples
398+
chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(self.shape))
399+
return (samples / np.sqrt(chi2_samples / nu)) + mu
398400

399401
def logp(self, value):
400402
"""

pymc3/tests/test_distributions_random.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -947,9 +947,9 @@ def ref_rand_evd(size, mu, evds, sigma):
947947

948948
def test_mv_t(self):
949949
def ref_rand(size, nu, Sigma, mu):
950-
normal = st.multivariate_normal.rvs(cov=Sigma, size=size).T
951-
chi2 = st.chi2.rvs(df=nu, size=size)
952-
return mu + np.sqrt(nu) * (normal / chi2).T
950+
normal = st.multivariate_normal.rvs(cov=Sigma, size=size)
951+
chi2 = st.chi2.rvs(df=nu, size=size)[..., None]
952+
return mu + (normal / np.sqrt(chi2 / nu))
953953

954954
for n in [2, 3]:
955955
pymc3_random(

0 commit comments

Comments
 (0)