@@ -133,18 +133,28 @@ def gelman_rubin(mtrace):
133
133
134
134
def calc_rhat (x ):
135
135
136
- m , n = x .shape
136
+ try :
137
+ # When the variable is multidimensional, this assignment will fail, triggering
138
+ # a ValueError that will handle the multidimensional case
139
+ m , n = x .shape
140
+
141
+ # Calculate between-chain variance
142
+ B = n * np .var (np .mean (x , axis = 1 ), ddof = 1 )
137
143
138
- # Calculate between -chain variance
139
- B = n * np .var (np .mean (x , axis = 1 ) , ddof = 1 )
144
+ # Calculate within -chain variance
145
+ W = np .mean (np .var (x , axis = 1 , ddof = 1 ) )
140
146
141
- # Calculate within-chain variance
142
- W = np . mean ( np . var ( x , axis = 1 , ddof = 1 ))
147
+ # Estimate of marginal posterior variance
148
+ Vhat = W * ( n - 1 ) / n + B / n
143
149
144
- # Estimate of marginal posterior variance
145
- Vhat = W * (n - 1 )/ n + B / n
150
+ return np .sqrt (Vhat / W )
151
+
152
+ except ValueError :
146
153
147
- return np .sqrt (Vhat / W )
154
+ # Tricky transpose here, shifting the last dimension to the first
155
+ rotated_indices = np .roll (np .arange (x .ndim ), 1 )
156
+ # Now iterate over the dimension of the variable
157
+ return np .squeeze ([calc_rhat (xi ) for xi in x .transpose (rotated_indices )])
148
158
149
159
Rhat = {}
150
160
for var in mtrace .varnames :
0 commit comments