Skip to content

Commit 6b46604

Browse files
Implementing Bill's comments
1 parent e6a7acd commit 6b46604

File tree

1 file changed

+41
-46
lines changed

1 file changed

+41
-46
lines changed

pymc/distributions/multivariate.py

Lines changed: 41 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2226,11 +2226,11 @@ class ICARRV(RandomVariable):
22262226
dtype = "floatX"
22272227
_print_name = ("ICAR", "\\operatorname{ICAR}")
22282228

2229-
def __call__(self, W, node1, node2, N, sigma, zero_sum_strength, size=None, **kwargs):
2230-
return super().__call__(W, node1, node2, N, sigma, zero_sum_strength, size=size, **kwargs)
2229+
def __call__(self, W, node1, node2, N, sigma, zero_sum_stdev, size=None, **kwargs):
2230+
return super().__call__(W, node1, node2, N, sigma, zero_sum_stdev, size=size, **kwargs)
22312231

22322232
@classmethod
2233-
def rng_fn(cls, rng, size, W, node1, node2, N, sigma, zero_sum_strength):
2233+
def rng_fn(cls, rng, size, W, node1, node2, N, sigma, zero_sum_stdev):
22342234
raise NotImplementedError("Cannot sample from ICAR prior")
22352235

22362236

@@ -2240,40 +2240,41 @@ def rng_fn(cls, rng, size, W, node1, node2, N, sigma, zero_sum_strength):
22402240
class ICAR(Continuous):
22412241
r"""
22422242
The intrinsic conditional autoregressive prior. It is primarily used to model
2243-
covariance between neighboring areas on large datasets. It is a special case
2243+
covariance between neighboring areas. It is a special case
22442244
of the :class:`~pymc.CAR` distribution where alpha is set to 1.
22452245
22462246
The log probability density function is
22472247
22482248
.. math::
2249-
f(\\phi| W,\\sigma) =
2250-
-\frac{1}{2\\sigma^{2}} \\sum_{i\\sim j} (\\phi_{i} - \\phi_{j})^2 -
2251-
\frac{1}{2}*\frac{\\sum_{i}{\\phi_{i}}}{0.001N}^{2} - \\ln{\\sqrt{2\\pi}} -
2252-
\\ln{0.001N}
2249+
f(\phi| W,\sigma) =
2250+
-\frac{1}{2\sigma^{2}} \sum_{i\sim j} (\phi_{i} - \phi_{j})^2 -
2251+
\frac{1}{2}*\frac{\sum_{i}{\phi_{i}}}{0.001N}^{2} - \ln{\sqrt{2\\pi}} -
2252+
\ln{0.001N}
22532253
22542254
The first term represents the spatial covariance component. Each $\\phi_{i}$ is penalized
22552255
based on the square distance from each of its neighbors. The notation $i\\sim j$
22562256
indicates a sum over all the neighbors of $\\phi_{i}$. The last three terms are the
22572257
Normal log density function where the mean is zero and the standard deviation is
2258-
$N * 0.001$ (where N is the length of the vector $\\phi$). This component imposed the zero-sum
2259-
constraint by finding the sum of the vector $\\phi$ and penalizing based on its
2260-
distance from zero.
2258+
$N * 0.001$ (where N is the length of the vector $\\phi$). This component imposes
2259+
a zero-sum constraint by finding the sum of the vector $\\phi$ and penalizing based
2260+
on its distance from zero.
22612261
22622262
Parameters
22632263
----------
22642264
W : ndarray of int
22652265
Symmetric adjacency matrix of 1s and 0s indicating adjacency between elements.
2266-
Must pass either W or both node1 and node2.
22672266
22682267
sigma : scalar, default 1
22692268
Standard deviation of the vector of phi's. Putting a prior on sigma
22702269
will result in a centered parameterization. In most cases, it is
22712270
preferable to use a non-centered parameterization by using the default
22722271
value and multiplying the resulting phi's by sigma. See the example below.
22732272
2274-
zero_sum_strength : scalar, default 0.001
2275-
Controls how strongly to enforce the zero-sum constraint. It sets the
2276-
standard deviation of a normal density function with mean zero.
2273+
zero_sum_stdev : scalar, default 0.001
2274+
Controls how strongly to enforce the zero-sum constraint. The sum of
2275+
phi is normally distributed with a mean of zero and small standard deviation.
2276+
This parameter sets the standard deviation of a normal density function with
2277+
mean zero.
22772278
22782279
22792280
Examples
@@ -2289,25 +2290,23 @@ class ICAR(Continuous):
22892290
# 4x4 adjacency matrix
22902291
# arranged in a square lattice
22912292
2292-
W = np.array([[0,1,0,1],
2293-
[1,0,1,0],
2294-
[0,1,0,1],
2295-
[1,0,1,0]])
2293+
W = np.array([
2294+
[0,1,0,1],
2295+
[1,0,1,0],
2296+
[0,1,0,1],
2297+
[1,0,1,0]
2298+
])
22962299
22972300
# centered parameterization
2298-
22992301
with pm.Model():
2300-
sigma = pm.Exponential('sigma',1)
2301-
phi = pm.ICAR('phi',W=W,sigma=sigma)
2302-
2302+
sigma = pm.Exponential('sigma', 1)
2303+
phi = pm.ICAR('phi', W=W, sigma=sigma)
23032304
mu = phi
23042305
23052306
# non-centered parameterization
2306-
23072307
with pm.Model():
2308-
sigma = pm.Exponential('sigma',1)
2309-
phi = pm.ICAR('phi',W=W)
2310-
2308+
sigma = pm.Exponential('sigma', 1)
2309+
phi = pm.ICAR('phi', W=W)
23112310
mu = sigma * phi
23122311
23132312
References
@@ -2326,12 +2325,7 @@ class ICAR(Continuous):
23262325
rv_op = icar
23272326

23282327
@classmethod
2329-
def dist(cls, W, sigma=1, zero_sum_strength=0.001, **kwargs):
2330-
# check that adjacency matrix is two dimensional,
2331-
# square,
2332-
# symmetrical
2333-
# and composed of 1s or 0s.
2334-
2328+
def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
23352329
if not W.ndim == 2:
23362330
raise ValueError("W must be matrix with ndim=2")
23372331

@@ -2345,6 +2339,12 @@ def dist(cls, W, sigma=1, zero_sum_strength=0.001, **kwargs):
23452339
raise ValueError("W must be composed of only 1s and 0s")
23462340

23472341
# convert adjacency matrix to edgelist representation
2342+
# An edgelist is a pair of lists.
2343+
# If node i and node j are connected then one list
2344+
# will contain i and the other will contain j at the same
2345+
# index value.
2346+
# We only use the lower triangle here because adjacency
2347+
# is a undirected connection.
23482348

23492349
node1, node2 = np.where(np.tril(W) == 1)
23502350

@@ -2356,30 +2356,25 @@ def dist(cls, W, sigma=1, zero_sum_strength=0.001, **kwargs):
23562356
N = pt.shape(W)[0]
23572357
N = pt.as_tensor_variable(N)
23582358

2359-
# check on sigma
2360-
23612359
sigma = pt.as_tensor_variable(floatX(sigma))
2360+
zero_sum_stdev = pt.as_tensor_variable(floatX(zero_sum_stdev))
23622361

2363-
# check on centering_strength
2364-
2365-
zero_sum_strength = pt.as_tensor_variable(floatX(zero_sum_strength))
2366-
2367-
return super().dist([W, node1, node2, N, sigma, zero_sum_strength], **kwargs)
2362+
return super().dist([W, node1, node2, N, sigma, zero_sum_stdev], **kwargs)
23682363

2369-
def moment(rv, size, W, node1, node2, N, sigma, zero_sum_strength):
2364+
def moment(rv, size, W, node1, node2, N, sigma, zero_sum_stdev):
23702365
return pt.zeros(N)
23712366

2372-
def logp(value, W, node1, node2, N, sigma, zero_sum_strength):
2367+
def logp(value, W, node1, node2, N, sigma, zero_sum_stdev):
23732368
pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(
23742369
pt.square(value[node1] - value[node2])
23752370
)
2376-
soft_center = (
2377-
-0.5 * pt.pow(pt.sum(value) / (zero_sum_strength * N), 2)
2371+
zero_sum = (
2372+
-0.5 * pt.pow(pt.sum(value) / (zero_sum_stdev * N), 2)
23782373
- pt.log(pt.sqrt(2.0 * np.pi))
2379-
- pt.log(zero_sum_strength * N)
2374+
- pt.log(zero_sum_stdev * N)
23802375
)
23812376

2382-
return check_parameters(pairwise_difference + soft_center, sigma > 0, msg="sigma > 0")
2377+
return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0")
23832378

23842379

23852380
class StickBreakingWeightsRV(RandomVariable):

0 commit comments

Comments
 (0)