@@ -256,6 +256,60 @@ def test_insert_draw(self):
256
256
chain ._get_row_at (2 , var_names = ["v1" ])
257
257
pass
258
258
259
+ def test_to_inferencedata_equalize_chain_lengths (self , caplog ):
260
+ run , chains = fully_initialized (
261
+ self .backend ,
262
+ make_runmeta (
263
+ variables = [
264
+ Variable ("A" , "uint16" , []),
265
+ ],
266
+ sample_stats = [Variable ("tune" , "bool" )],
267
+ data = [],
268
+ ),
269
+ nchains = 2 ,
270
+ )
271
+ # Checks run.to_inferencedata(equalize_chain_lengths=...) on two chains of uneven lengths:
272
+ # - Chain 0 has 5 tune and 15 draws (length 20)
273
+ # - Chain 1 has 5 tune and 14 draws (length 19)
274
+ # This simulates the situation where chains aren't synchronized.
275
+ ntune = 5
276
+
277
+ c0 = chains [0 ]
278
+ for i in range (0 , 20 ):
279
+ c0 .append (dict (A = i ), stats = dict (tune = i < ntune ))
280
+
281
+ c1 = chains [1 ]
282
+ for i in range (0 , 19 ):
283
+ c1 .append (dict (A = i ), stats = dict (tune = i < ntune ))
284
+
285
+ assert len (c0 ) == 20
286
+ assert len (c1 ) == 19
287
+
288
+ # With equalize=True all chains should have the length of the shortest (here: 19)
289
+ # But the first 5 (ntune) are tuning, so 19-5 = 14 posterior draws remain.
290
+ with caplog .at_level (logging .WARNING ):
291
+ idata_even = run .to_inferencedata (equalize_chain_lengths = True )
292
+ assert "Chains vary in length" in caplog .records [0 ].message
293
+ assert "Truncating to" in caplog .records [0 ].message
294
+ assert len (idata_even .posterior .draw ) == 14
295
+
296
+ # With equalize=False nothing is truncated: the chains keep 20-5 = 15 and 19-5 = 14 posterior values.
297
+ caplog .clear ()
298
+ with caplog .at_level (logging .WARNING ):
299
+ idata_uneven = run .to_inferencedata (equalize_chain_lengths = False )
300
+ # These are the messed-up chain and draw dimensions: 1 "chain" holding 2 "draws" instead of 2 chains!
301
+ assert idata_uneven .posterior .dims ["chain" ] == 1
302
+ assert idata_uneven .posterior .dims ["draw" ] == 2
303
+ # The "draws" are actually the chains, but in a weird scalar object-array?!
304
+ # Doing .tolist() seems to be the only way to get our hands on it.
305
+ d1 = idata_uneven .posterior .A .sel (chain = 0 , draw = 0 ).values .tolist ()
306
+ d2 = idata_uneven .posterior .A .sel (chain = 0 , draw = 1 ).values .tolist ()
307
+ numpy .testing .assert_array_equal (d1 , list (range (ntune , 20 )))
308
+ numpy .testing .assert_array_equal (d2 , list (range (ntune , 19 )))
309
+ assert "Chains vary in length" in caplog .records [0 ].message
310
+ assert "see ArviZ issue #2094" in caplog .records [0 ].message
311
+ pass
312
+
259
313
260
314
if __name__ == "__main__" :
261
315
tc = TestClickHouseBackend ()
0 commit comments