Skip to content

Commit a5c1aae

Browse files
authored
Add boolean flag to equalize chain lengths (#46)
Closes #44
1 parent 2d3b21f commit a5c1aae

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

mcbackend/core.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,13 @@ def constant_data(self) -> Dict[str, numpy.ndarray]:
152152
def observed_data(self) -> Dict[str, numpy.ndarray]:
153153
return {dv.name: ndarray_to_numpy(dv.value) for dv in self.meta.data if dv.is_observed}
154154

155-
def to_inferencedata(self, **kwargs) -> InferenceData:
155+
def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> InferenceData:
156156
"""Creates an ArviZ ``InferenceData`` object from this run.
157157
158158
Parameters
159159
----------
160+
equalize_chain_lengths : bool
161+
Whether to truncate all chains to the shortest chain length (default: ``True``).
160162
**kwargs
161163
Will be forwarded to ``arviz.from_dict()``.
162164
@@ -181,16 +183,20 @@ def to_inferencedata(self, **kwargs) -> InferenceData:
181183
chain_lengths = {c.cid: len(c) for c in chains}
182184
if len(set(chain_lengths.values())) != 1:
183185
_log.warning("Chains vary in length. Lenghts are: %s", chain_lengths)
184-
186+
clen = None
187+
if equalize_chain_lengths:
188+
# A minimum chain length is introduced so that all chains have equal length
189+
clen = min(chain_lengths.values())
185190
# Aggregate draws and stats, while splitting into warmup/posterior
186191
warmup_posterior = collections.defaultdict(list)
187192
warmup_sample_stats = collections.defaultdict(list)
188193
posterior = collections.defaultdict(list)
189194
sample_stats = collections.defaultdict(list)
190195
for c, chain in enumerate(chains):
191-
# Every retrieved array is shortened to the previously determined chain length.
192-
# This is needed for database backends which may get inserts inbetween.
193-
clen = chain_lengths[chain.cid]
196+
if clen is None:
197+
# Every retrieved array is shortened to the previously determined chain length.
198+
# This is needed for database backends which may get inserts inbetween.
199+
clen = chain_lengths[chain.cid]
194200

195201
# Obtain a mask by which draws can be split into warmup/posterior
196202
if "tune" in chain.sample_stats:

0 commit comments

Comments
 (0)