Skip to content

Commit dfd4d38

Browse files
Store PyMC SamplerWarning stats as str by pickling
Closes #73
1 parent 1927579 commit dfd4d38

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

mcbackend/adapters/pymc.py

+28-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
44
The only PyMC dependency is on the ``BaseTrace`` abstract base class.
55
"""
6+
import base64
7+
import pickle
68
from typing import Dict, List, Optional, Sequence, Tuple
79

810
import hagelkorn
@@ -159,12 +161,16 @@ def setup(
159161
self._stat_groups.append([])
160162
for statname, dtype in names_dtypes.items():
161163
sname = f"sampler_{s}__{statname}"
162-
svar = Variable(
163-
name=sname,
164-
dtype=numpy.dtype(dtype).name,
165-
# This 👇 is needed until PyMC provides shapes ahead of time.
166-
undefined_ndim=True,
167-
)
164+
if statname == "warning":
165+
# SamplerWarnings will be pickled and stored as string!
166+
svar = Variable(sname, "str")
167+
else:
168+
svar = Variable(
169+
name=sname,
170+
dtype=numpy.dtype(dtype).name,
171+
# This 👇 is needed until PyMC provides shapes ahead of time.
172+
undefined_ndim=True,
173+
)
168174
self._stat_groups[s].append((sname, statname))
169175
sample_stats.append(svar)
170176

@@ -197,8 +203,12 @@ def record(self, point, sampler_states=None):
197203
for s, sts in enumerate(sampler_states):
198204
for statname, sval in sts.items():
199205
sname = f"sampler_{s}__{statname}"
200-
stats[sname] = sval
201-
# Make not whether this is a tuning iteration.
206+
# Automatically pickle SamplerWarnings
207+
if statname == "warning":
208+
sval_bytes = pickle.dumps(sval)
209+
sval = base64.encodebytes(sval_bytes).decode("ascii")
210+
stats[sname] = numpy.asarray(sval)
211+
# Make note whether this is a tuning iteration.
202212
if statname == "tune":
203213
stats["tune"] = sval
204214

@@ -214,7 +224,16 @@ def get_values(self, varname, burn=0, thin=1) -> numpy.ndarray:
214224
def _get_stats(self, varname, burn=0, thin=1) -> numpy.ndarray:
215225
if self._chain is None:
216226
raise Exception("Trace setup was not completed. Call `.setup()` first.")
217-
return self._chain.get_stats(varname)[burn::thin]
227+
values = self._chain.get_stats(varname)[burn::thin]
228+
if "warning" in varname:
229+
objs = []
230+
for v in values:
231+
enc = v.encode("ascii")
232+
str_ = base64.decodebytes(enc)
233+
obj = pickle.loads(str_)
234+
objs.append(obj)
235+
values = numpy.array(objs, dtype=object)
236+
return values
218237

219238
def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
220239
if self._chain is None:

mcbackend/test_adapter_pymc.py

-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def teardown_method(self, method):
6969
self._client_main.disconnect()
7070
return
7171

72-
@pytest.mark.xfail(reason="Warning stats are objects. See #73.")
7372
@pytest.mark.parametrize("cores", [1, 3])
7473
def test_cores(self, simple_model, cores):
7574
backend = ClickHouseBackend(self._client)

0 commit comments

Comments
 (0)