Skip to content

Commit 1511ddb

Browse files
Get tune mask from "tune" or "*__tune" stats
This restores compatibility with MCMC runs from PyMC >= 5.7.0. Closes #102
1 parent a67d391 commit 1511ddb

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

mcbackend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
except ModuleNotFoundError:
1313
pass
1414

15-
__version__ = "0.5.1"
15+
__version__ = "0.5.2"
1616
__all__ = [
1717
"NumPyBackend",
1818
"Backend",

mcbackend/core.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
_log = logging.getLogger(__file__)
2424

2525

26+
__all__ = ("is_rigid", "chain_id", "Chain", "Run", "Backend")
27+
28+
2629
def is_rigid(nshape: Optional[Shape]):
2730
"""Determines wheather the shape is constant.
2831
@@ -133,6 +136,20 @@ def sample_stats(self) -> Dict[str, Variable]:
133136
return {var.name: var for var in self.rmeta.sample_stats}
134137

135138

139+
def get_tune_mask(chain: Chain, slc: slice = slice(None)) -> numpy.ndarray:
140+
"""Load the tuning mask from either a ``"tune"``, or a ``"*__tune"`` stat.
141+
142+
Raises
143+
------
144+
KeyError
145+
When no matching stat is found.
146+
"""
147+
for sname in chain.sample_stats:
148+
if sname.endswith("__tune") or sname == "tune":
149+
return chain.get_stats(sname, slc).astype(bool)
150+
raise KeyError("No tune stat found.")
151+
152+
136153
class Run:
137154
"""A handle on one MCMC run."""
138155

@@ -231,14 +248,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
231248
slc = slice(0, min_clen)
232249

233250
# Obtain a mask by which draws can be split into warmup/posterior
234-
if "tune" in chain.sample_stats:
235-
tune = chain.get_stats("tune", slc).astype(bool)
236-
else:
251+
try:
252+
# Use the same slice to avoid shape issues in case the chain is still active
253+
tune = get_tune_mask(chain, slc)
254+
except KeyError:
237255
if c == 0:
238256
_log.warning(
239257
"No 'tune' stat found. Assuming all iterations are posterior draws."
240258
)
241-
tune = numpy.full((chain_lengths[chain.cid],), False)
259+
tune = numpy.full((slc.stop,), False)
242260

243261
# Split all variables draws into warmup/posterior
244262
for var in variables:

0 commit comments

Comments
 (0)