Skip to content

Commit e6a7acd

Browse files
removed asserts, added tests on sizes
1 parent 6f0579a commit e6a7acd

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

pymc/distributions/multivariate.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,12 +2259,6 @@ class ICAR(Continuous):
22592259
constraint by finding the sum of the vector $\\phi$ and penalizing based on its
22602260
distance from zero.
22612261
2262-
======== ==========================================
2263-
Support :math:`x \\in \\mathbb{R}^k`
2264-
Mean :math:`0`
2265-
Variance :math:`T^{-1}` ?
2266-
======== ==========================================
2267-
22682262
Parameters
22692263
----------
22702264
W : ndarray of int
@@ -2365,14 +2359,10 @@ def dist(cls, W, sigma=1, zero_sum_strength=0.001, **kwargs):
23652359
# check on sigma
23662360

23672361
sigma = pt.as_tensor_variable(floatX(sigma))
2368-
sigma = Assert("sigma > 0")(sigma, pt.gt(sigma, 0))
23692362

23702363
# check on centering_strength
23712364

23722365
zero_sum_strength = pt.as_tensor_variable(floatX(zero_sum_strength))
2373-
zero_sum_strength = Assert("centering_strength > 0")(
2374-
zero_sum_strength, pt.gt(zero_sum_strength, 0)
2375-
)
23762366

23772367
return super().dist([W, node1, node2, N, sigma, zero_sum_strength], **kwargs)
23782368

tests/distributions/test_multivariate.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2093,7 +2093,15 @@ class TestICAR(BaseTestDistributionRandom):
20932093
"sigma": 2,
20942094
"zero_sum_strength": 0.001,
20952095
}
2096-
checks_to_run = ["check_pymc_params_match_rv_op"]
2096+
checks_to_run = ["check_pymc_params_match_rv_op", "check_rv_inferred_size"]
2097+
2098+
def check_rv_inferred_size(self):
2099+
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
2100+
sizes_expected = [(3,), (3,), (1, 3), (1, 3), (5, 3), (4, 5, 3), (2, 4, 2, 3)]
2101+
for size, expected in zip(sizes_to_check, sizes_expected):
2102+
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
2103+
expected_symbolic = tuple(pymc_rv.shape.eval())
2104+
assert expected_symbolic == expected
20972105

20982106
def test_icar_logp(self):
20992107
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])

0 commit comments

Comments
 (0)