Skip to content

Commit 299ea68

Browse files
fix check for square matrix + check_pymc_match_rv_op
1 parent bf8be00 commit 299ea68

File tree

2 files changed

+33
-27
lines changed

2 files changed

+33
-27
lines changed

pymc/distributions/multivariate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,12 +2334,16 @@ class ICAR(Continuous):
23342334
@classmethod
23352335
def dist(cls, W, sigma=1, zero_sum_strength=0.001, **kwargs):
23362336
# check that adjacency matrix is two dimensional,
2337+
# square,
23372338
# symmetrical
23382339
# and composed of 1s or 0s.
23392340

23402341
if not W.ndim == 2:
23412342
raise ValueError("W must be matrix with ndim=2")
23422343

2344+
if not W.shape[0] == W.shape[1]:
2345+
raise ValueError("W must be a square matrix")
2346+
23432347
if not np.allclose(W.T, W):
23442348
raise ValueError("W must be a symmetric matrix")
23452349

tests/distributions/test_multivariate.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,19 +2085,39 @@ def check_draws_match_expected(self):
20852085
class TestICAR(BaseTestDistributionRandom):
20862086
pymc_dist = pm.ICAR
20872087
pymc_dist_params = {"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]), "sigma": 2}
2088-
expected_rv_op_params = {"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]), "sigma": 2}
2089-
sizes_to_check = [None, (1), (2), (2, 2)]
2090-
sizes_expected = [(3,), (3), (6), (12)]
2091-
checks_to_run = [
2092-
"check_pymc_params_match_rv_op",
2093-
"check_rv_size",
2094-
]
2088+
expected_rv_op_params = {
2089+
"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]),
2090+
"node1": np.array([1, 2, 2]),
2091+
"node2": np.array([0, 0, 1]),
2092+
"N": 3,
2093+
"sigma": 2,
2094+
"zero_sum_strength": 0.001,
2095+
}
2096+
checks_to_run = ["check_pymc_params_match_rv_op"]
2097+
2098+
def test_icar_logp(self):
2099+
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2100+
2101+
with pm.Model() as m:
2102+
RV = pm.ICAR("phi", W=W)
2103+
2104+
assert pt.isclose(
2105+
pm.logp(RV, np.array([0.01, -0.03, 0.02, 0.00])).eval(), np.array(4.60022238)
2106+
).eval(), "logp inaccuracy"
2107+
2108+
def test_icar_rng_fn(self):
2109+
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2110+
2111+
RV = pm.ICAR.dist(W=W)
2112+
2113+
with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
2114+
pm.draw(RV)
20952115

20962116
@pytest.mark.parametrize(
20972117
"W,msg",
20982118
[
20992119
(np.array([0, 1, 0, 0]), "W must be matrix with ndim=2"),
2100-
(np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1]]), "W must be a symmetric matrix"),
2120+
(np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1]]), "W must be a square matrix"),
21012121
(
21022122
np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]]),
21032123
"W must be a symmetric matrix",
@@ -2108,29 +2128,11 @@ class TestICAR(BaseTestDistributionRandom):
21082128
),
21092129
],
21102130
)
2111-
def test_icar_matrix_checks(W, msg):
2131+
def test_icar_matrix_checks(self, W, msg):
21122132
with pytest.raises(ValueError, match=msg):
21132133
with pm.Model():
21142134
pm.ICAR("phi", W=W)
21152135

2116-
def test_icar_logp():
2117-
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2118-
2119-
with pm.Model() as m:
2120-
RV = pm.ICAR("phi", W=W)
2121-
2122-
assert pt.isclose(
2123-
pm.logp(RV, np.array([0.01, -0.03, 0.02, 0.00])).eval(), np.array(4.60022238)
2124-
).eval(), "logp inaccuracy"
2125-
2126-
def test_icar_rng_fn():
2127-
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2128-
2129-
RV = pm.ICAR.dist(W=W)
2130-
2131-
with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
2132-
pm.draw(RV)
2133-
21342136

21352137
@pytest.mark.parametrize("sparse", [True, False])
21362138
def test_car_rng_fn(sparse):

0 commit comments

Comments
 (0)