Skip to content

Commit 978bfc2

Browse files
Add test and logging uneven chain lengths
This is a follow-up to a5c1aae.
1 parent a5c1aae commit 978bfc2

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ disable =
66
C0116, # some methods are just too simple to deserve a docstring
77
R0902, # too many instance attributes
88
R0903, # too few public methods
9+
R0912, # too many branches
910
R0913, # too many arguments
1011
R0914, # too many local variables
1112
R1711, # useless return is okay

mcbackend/core.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,21 +182,36 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
182182

183183
chain_lengths = {c.cid: len(c) for c in chains}
184184
if len(set(chain_lengths.values())) != 1:
185-
_log.warning("Chains vary in length. Lenghts are: %s", chain_lengths)
186-
clen = None
185+
msg = f"Chains vary in length. Lenghts are: {chain_lengths}"
186+
if not equalize_chain_lengths:
187+
msg += (
188+
"\nArviZ does not properly support uneven chain lengths (see ArviZ issue #2094)."
189+
"\nWe'll try to give you an InferenceData, but best case the chain & draw dimensions"
190+
" will be messed-up as {'chain': 1, 'draws': n_chains}."
191+
"\nYou won't be able to save this InferenceData to a file"
192+
" and you should expect many ArviZ functions to choke on it."
193+
"\nSpecify `to_inferencedata(equalize_chain_lengths=True)` to get regular InferenceData."
194+
)
195+
else:
196+
msg += "\nTruncating to the length of the shortest chain."
197+
_log.warning(msg)
198+
min_clen = None
187199
if equalize_chain_lengths:
188200
# A minimum chain length is introduced so that all chains have equal length
189-
clen = min(chain_lengths.values())
201+
min_clen = min(chain_lengths.values())
190202
# Aggregate draws and stats, while splitting into warmup/posterior
191203
warmup_posterior = collections.defaultdict(list)
192204
warmup_sample_stats = collections.defaultdict(list)
193205
posterior = collections.defaultdict(list)
194206
sample_stats = collections.defaultdict(list)
195207
for c, chain in enumerate(chains):
196-
if clen is None:
208+
# NOTE: Replace the truncation with a ranged fetch once issue #47 is resolved.
209+
if min_clen is None:
197210
# Every retrieved array is shortened to the previously determined chain length.
198211
# This is needed for database backends which may get inserts inbetween.
199212
clen = chain_lengths[chain.cid]
213+
else:
214+
clen = min_clen
200215

201216
# Obtain a mask by which draws can be split into warmup/posterior
202217
if "tune" in chain.sample_stats:

mcbackend/test_backend_clickhouse.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,60 @@ def test_insert_draw(self):
256256
chain._get_row_at(2, var_names=["v1"])
257257
pass
258258

259+
def test_to_inferencedata_equalize_chain_lengths(self, caplog):
260+
run, chains = fully_initialized(
261+
self.backend,
262+
make_runmeta(
263+
variables=[
264+
Variable("A", "uint16", []),
265+
],
266+
sample_stats=[Variable("tune", "bool")],
267+
data=[],
268+
),
269+
nchains=2,
270+
)
271+
# Create chains of uneven lengths:
272+
# - Chain 0 has 5 tune and 15 draws (length 20)
273+
# - Chain 1 has 5 tune and 14 draws (length 19)
274+
# This simulates the situation where chains aren't synchronized.
275+
ntune = 5
276+
277+
c0 = chains[0]
278+
for i in range(0, 20):
279+
c0.append(dict(A=i), stats=dict(tune=i < ntune))
280+
281+
c1 = chains[1]
282+
for i in range(0, 19):
283+
c1.append(dict(A=i), stats=dict(tune=i < ntune))
284+
285+
assert len(c0) == 20
286+
assert len(c1) == 19
287+
288+
# With equalize=True all chains should have the length of the shortest (here: 7)
289+
# But the first 3 are tuning, so 4 posterior draws remain.
290+
with caplog.at_level(logging.WARNING):
291+
idata_even = run.to_inferencedata(equalize_chain_lengths=True)
292+
assert "Chains vary in length" in caplog.records[0].message
293+
assert "Truncating to" in caplog.records[0].message
294+
assert len(idata_even.posterior.draw) == 14
295+
296+
# With equalize=False the "draw" dim has the length of the longest chain (here: 8-3 = 5)
297+
caplog.clear()
298+
with caplog.at_level(logging.WARNING):
299+
idata_uneven = run.to_inferencedata(equalize_chain_lengths=False)
300+
# These are the messed-up chain and draw dimensions!
301+
assert idata_uneven.posterior.dims["chain"] == 1
302+
assert idata_uneven.posterior.dims["draw"] == 2
303+
# The "draws" are actually the chains, but in a weird scalar object-array?!
304+
# Doing .tolist() seems to be the only way to get our hands on it.
305+
d1 = idata_uneven.posterior.A.sel(chain=0, draw=0).values.tolist()
306+
d2 = idata_uneven.posterior.A.sel(chain=0, draw=1).values.tolist()
307+
numpy.testing.assert_array_equal(d1, list(range(ntune, 20)))
308+
numpy.testing.assert_array_equal(d2, list(range(ntune, 19)))
309+
assert "Chains vary in length" in caplog.records[0].message
310+
assert "see ArviZ issue #2094" in caplog.records[0].message
311+
pass
312+
259313

260314
if __name__ == "__main__":
261315
tc = TestClickHouseBackend()

0 commit comments

Comments
 (0)