Skip to content

Add test and logging uneven chain lengths #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ disable =
C0116, # some methods are just too simple to deserve a docstring
R0902, # too many instance attributes
R0903, # too few public methods
R0912, # too many branches
R0913, # too many arguments
R0914, # too many local variables
R1711, # useless return is okay
2 changes: 1 addition & 1 deletion mcbackend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
pass


__version__ = "0.1.2"
__version__ = "0.1.3"
23 changes: 19 additions & 4 deletions mcbackend/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,36 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->

chain_lengths = {c.cid: len(c) for c in chains}
if len(set(chain_lengths.values())) != 1:
_log.warning("Chains vary in length. Lenghts are: %s", chain_lengths)
clen = None
msg = f"Chains vary in length. Lenghts are: {chain_lengths}"
if not equalize_chain_lengths:
msg += (
"\nArviZ does not properly support uneven chain lengths (see ArviZ issue #2094)."
"\nWe'll try to give you an InferenceData, but best case the chain & draw dimensions"
" will be messed-up as {'chain': 1, 'draws': n_chains}."
"\nYou won't be able to save this InferenceData to a file"
" and you should expect many ArviZ functions to choke on it."
"\nSpecify `to_inferencedata(equalize_chain_lengths=True)` to get regular InferenceData."
)
else:
msg += "\nTruncating to the length of the shortest chain."
_log.warning(msg)
min_clen = None
if equalize_chain_lengths:
# A minimum chain length is introduced so that all chains have equal length
clen = min(chain_lengths.values())
min_clen = min(chain_lengths.values())
# Aggregate draws and stats, while splitting into warmup/posterior
warmup_posterior = collections.defaultdict(list)
warmup_sample_stats = collections.defaultdict(list)
posterior = collections.defaultdict(list)
sample_stats = collections.defaultdict(list)
for c, chain in enumerate(chains):
if clen is None:
# NOTE: Replace the truncation with a ranged fetch once issue #47 is resolved.
if min_clen is None:
# Every retrieved array is shortened to the previously determined chain length.
# This is needed for database backends which may get inserts inbetween.
clen = chain_lengths[chain.cid]
else:
clen = min_clen

# Obtain a mask by which draws can be split into warmup/posterior
if "tune" in chain.sample_stats:
Expand Down
54 changes: 54 additions & 0 deletions mcbackend/test_backend_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,60 @@ def test_insert_draw(self):
chain._get_row_at(2, var_names=["v1"])
pass

def test_to_inferencedata_equalize_chain_lengths(self, caplog):
run, chains = fully_initialized(
self.backend,
make_runmeta(
variables=[
Variable("A", "uint16", []),
],
sample_stats=[Variable("tune", "bool")],
data=[],
),
nchains=2,
)
# Create chains of uneven lengths:
# - Chain 0 has 5 tune and 15 draws (length 20)
# - Chain 1 has 5 tune and 14 draws (length 19)
# This simulates the situation where chains aren't synchronized.
ntune = 5

c0 = chains[0]
for i in range(0, 20):
c0.append(dict(A=i), stats=dict(tune=i < ntune))

c1 = chains[1]
for i in range(0, 19):
c1.append(dict(A=i), stats=dict(tune=i < ntune))

assert len(c0) == 20
assert len(c1) == 19

# With equalize=True all chains should have the length of the shortest (here: 7)
# But the first 3 are tuning, so 4 posterior draws remain.
with caplog.at_level(logging.WARNING):
idata_even = run.to_inferencedata(equalize_chain_lengths=True)
assert "Chains vary in length" in caplog.records[0].message
assert "Truncating to" in caplog.records[0].message
assert len(idata_even.posterior.draw) == 14

# With equalize=False the "draw" dim has the length of the longest chain (here: 8-3 = 5)
caplog.clear()
with caplog.at_level(logging.WARNING):
idata_uneven = run.to_inferencedata(equalize_chain_lengths=False)
# These are the messed-up chain and draw dimensions!
assert idata_uneven.posterior.dims["chain"] == 1
assert idata_uneven.posterior.dims["draw"] == 2
# The "draws" are actually the chains, but in a weird scalar object-array?!
# Doing .tolist() seems to be the only way to get our hands on it.
d1 = idata_uneven.posterior.A.sel(chain=0, draw=0).values.tolist()
d2 = idata_uneven.posterior.A.sel(chain=0, draw=1).values.tolist()
numpy.testing.assert_array_equal(d1, list(range(ntune, 20)))
numpy.testing.assert_array_equal(d2, list(range(ntune, 19)))
assert "Chains vary in length" in caplog.records[0].message
assert "see ArviZ issue #2094" in caplog.records[0].message
pass


if __name__ == "__main__":
tc = TestClickHouseBackend()
Expand Down