|
23 | 23 | _log = logging.getLogger(__file__)
|
24 | 24 |
|
25 | 25 |
|
| 26 | +__all__ = ("is_rigid", "chain_id", "Chain", "Run", "Backend") |
| 27 | + |
| 28 | + |
26 | 29 | def is_rigid(nshape: Optional[Shape]):
|
27 | 30 | """Determines wheather the shape is constant.
|
28 | 31 |
|
@@ -133,6 +136,20 @@ def sample_stats(self) -> Dict[str, Variable]:
|
133 | 136 | return {var.name: var for var in self.rmeta.sample_stats}
|
134 | 137 |
|
135 | 138 |
|
| 139 | +def get_tune_mask(chain: Chain, slc: slice = slice(None)) -> numpy.ndarray: |
| 140 | + """Load the tuning mask from either a ``"tune"``, or a ``"*__tune"`` stat. |
| 141 | +
|
| 142 | + Raises |
| 143 | + ------ |
| 144 | + KeyError |
| 145 | + When no matching stat is found. |
| 146 | + """ |
| 147 | + for sname in chain.sample_stats: |
| 148 | + if sname.endswith("__tune") or sname == "tune": |
| 149 | + return chain.get_stats(sname, slc).astype(bool) |
| 150 | + raise KeyError("No tune stat found.") |
| 151 | + |
| 152 | + |
136 | 153 | class Run:
|
137 | 154 | """A handle on one MCMC run."""
|
138 | 155 |
|
@@ -231,14 +248,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
|
231 | 248 | slc = slice(0, min_clen)
|
232 | 249 |
|
233 | 250 | # Obtain a mask by which draws can be split into warmup/posterior
|
234 |
| - if "tune" in chain.sample_stats: |
235 |
| - tune = chain.get_stats("tune", slc).astype(bool) |
236 |
| - else: |
| 251 | + try: |
| 252 | + # Use the same slice to avoid shape issues in case the chain is still active |
| 253 | + tune = get_tune_mask(chain, slc) |
| 254 | + except KeyError: |
237 | 255 | if c == 0:
|
238 | 256 | _log.warning(
|
239 | 257 | "No 'tune' stat found. Assuming all iterations are posterior draws."
|
240 | 258 | )
|
241 |
| - tune = numpy.full((chain_lengths[chain.cid],), False) |
| 259 | + tune = numpy.full((slc.stop,), False) |
242 | 260 |
|
243 | 261 | # Split all variables draws into warmup/posterior
|
244 | 262 | for var in variables:
|
|
0 commit comments