Skip to content

Commit f369137

Browse files
authored
Added check that nu must be a scalar in MvStudentTRV (#5241)
1 parent a4f9657 commit f369137

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

pymc/distributions/multivariate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,13 @@ class MvStudentTRV(RandomVariable):
262262
dtype = "floatX"
263263
_print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
264264

265+
def make_node(self, rng, size, dtype, nu, mu, cov):
266+
nu = at.as_tensor_variable(nu)
267+
if not nu.ndim == 0:
268+
raise ValueError("nu must be a scalar (ndim=0).")
269+
270+
return super().make_node(rng, size, dtype, nu, mu, cov)
271+
265272
def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
266273

267274
dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype

pymc/tests/test_distributions_random.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import functools
1515
import itertools
16+
import re
1617

1718
from typing import Callable, List, Optional
1819

@@ -1115,8 +1116,20 @@ def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
11151116
"check_pymc_params_match_rv_op",
11161117
"check_pymc_draws_match_reference",
11171118
"check_rv_size",
1119+
"test_errors",
11181120
]
11191121

1122+
def test_errors(self):
1123+
msg = "nu must be a scalar (ndim=0)."
1124+
with pm.Model():
1125+
with pytest.raises(ValueError, match=re.escape(msg)):
1126+
mvstudentt = pm.MvStudentT(
1127+
"mvstudentt",
1128+
nu=np.array([1, 2]),
1129+
mu=np.ones(2),
1130+
cov=np.full((2, 2), np.ones(2)),
1131+
)
1132+
11201133

11211134
class TestMvStudentTChol(BaseTestDistribution):
11221135
pymc_dist = pm.MvStudentT

0 commit comments

Comments
 (0)