Skip to content

Commit 607950e

Browse files
author
Christopher Fonnesbeck
committed
Merge branch 'rhat_fix'
2 parents a272d36 + 3dc852c commit 607950e

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

pymc/diagnostics.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,28 @@ def gelman_rubin(mtrace):
133133

134134
def calc_rhat(x):
135135

136-
m, n = x.shape
136+
try:
137+
# When the variable is multidimensional, this assignment will fail, triggering
138+
# a ValueError that will handle the multidimensional case
139+
m, n = x.shape
140+
141+
# Calculate between-chain variance
142+
B = n * np.var(np.mean(x, axis=1), ddof=1)
137143

138-
# Calculate between-chain variance
139-
B = n * np.var(np.mean(x, axis=1), ddof=1)
144+
# Calculate within-chain variance
145+
W = np.mean(np.var(x, axis=1, ddof=1))
140146

141-
# Calculate within-chain variance
142-
W = np.mean(np.var(x, axis=1, ddof=1))
147+
# Estimate of marginal posterior variance
148+
Vhat = W*(n - 1)/n + B/n
143149

144-
# Estimate of marginal posterior variance
145-
Vhat = W*(n - 1)/n + B/n
150+
return np.sqrt(Vhat/W)
151+
152+
except ValueError:
146153

147-
return np.sqrt(Vhat/W)
154+
# Tricky transpose here, shifting the last dimension to the first
155+
rotated_indices = np.roll(np.arange(x.ndim), 1)
156+
# Now iterate over the dimension of the variable
157+
return np.squeeze([calc_rhat(xi) for xi in x.transpose(rotated_indices)])
148158

149159
Rhat = {}
150160
for var in mtrace.varnames:

0 commit comments

Comments
 (0)