Skip to content

Commit 415e075

Browse files
Pass slice to get draws/stats in to_inferencedata
1 parent 8faf25e commit 415e075

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

mcbackend/core.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,17 +223,17 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
223223
posterior = collections.defaultdict(list)
224224
sample_stats = collections.defaultdict(list)
225225
for c, chain in enumerate(chains):
226-
# NOTE: Replace the truncation with a ranged fetch once issue #47 is resolved.
226+
# Create a slice to use when fetching the variables
227227
if min_clen is None:
228228
# Every retrieved array is shortened to the previously determined chain length.
229-
# This is needed for database backends which may get inserts inbetween.
230-
clen = chain_lengths[chain.cid]
229+
# Needed for backends which may get inserts inbetween our get_draws/get_stats calls.
230+
slc = slice(0, chain_lengths[chain.cid])
231231
else:
232-
clen = min_clen
232+
slc = slice(0, min_clen)
233233

234234
# Obtain a mask by which draws can be split into warmup/posterior
235235
if "tune" in chain.sample_stats:
236-
tune = chain.get_stats("tune")[:clen].astype(bool)
236+
tune = chain.get_stats("tune", slc).astype(bool)
237237
else:
238238
if c == 0:
239239
_log.warning(
@@ -243,12 +243,12 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
243243

244244
# Split all variables draws into warmup/posterior
245245
for var in variables:
246-
draws = chain.get_draws(var.name)[:clen]
246+
draws = chain.get_draws(var.name, slc)
247247
warmup_posterior[var.name].append(draws[tune])
248248
posterior[var.name].append(draws[~tune])
249249
# Same for sample stats
250250
for svar in self.meta.sample_stats:
251-
stats = chain.get_stats(svar.name)[:clen]
251+
stats = chain.get_stats(svar.name, slc)
252252
warmup_sample_stats[svar.name].append(stats[tune])
253253
sample_stats[svar.name].append(stats[~tune])
254254

0 commit comments

Comments
 (0)