Skip to content

Commit 26859b1

Browse files
logp test, test rename, np.all - > np.any
1 parent eba76f6 commit 26859b1

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2078,7 +2078,7 @@ def rng_fn(cls, rng: np.random.RandomState, mu, W, alpha, tau, size):
20782078
Journal of the Royal Statistical Society Series B, Royal Statistical Society,
20792079
vol. 63(2), pages 325-338. DOI: 10.1111/1467-9868.00288
20802080
"""
2081-
if np.all(alpha >= 1) or np.all(alpha <= -1):
2081+
if np.any(alpha >= 1) or np.any(alpha <= -1):
20822082
raise ValueError("the domain of alpha is: -1 < alpha < 1")
20832083

20842084
if not scipy.sparse.issparse(W):

tests/distributions/test_multivariate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def test_car_matrix_check(sparse):
836836

837837

838838
@pytest.mark.parametrize("alpha", [1, -1])
839-
def test_car_alpha(alpha):
839+
def test_car_alpha_bounds(alpha):
840840
"""
841841
Tests the check that -1 < alpha < 1
842842
"""
@@ -845,12 +845,16 @@ def test_car_alpha(alpha):
845845

846846
tau = 1
847847
mu = np.array([0, 0, 0])
848+
values = np.array([-0.5, 0, 0.5])
848849

849850
car_dist = pm.CAR.dist(W=W, alpha=alpha, mu=mu, tau=tau)
850851

851852
with pytest.raises(ValueError, match="the domain of alpha is: -1 < alpha < 1"):
852853
pm.draw(car_dist)
853854

855+
with pytest.raises(ValueError, match="-1 < alpha < 1, tau > 0"):
856+
pm.logp(car_dist, values).eval()
857+
854858

855859
class TestLKJCholeskCov:
856860
def test_dist(self):

0 commit comments

Comments
 (0)