Skip to content

Commit cd1d354

Browse files
Mark object stats as str-typed
1 parent 65eb592 commit cd1d354

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

pymc/backends/mcbackend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,13 @@ def make_runmeta_and_point_fn(
221221
(-1 if s is None else s)
222222
for s in (shape or [])
223223
]
224+
dt = np.dtype(dtype).name
225+
# Object types will be pickled by the ChainRecordAdapter!
226+
if dt == "object":
227+
dt = "str"
224228
svar = mcb.Variable(
225229
name=sname,
226-
dtype=np.dtype(dtype).name,
230+
dtype=dt,
227231
shape=sshape,
228232
undefined_ndim=shape is None,
229233
)

tests/backends/test_mcbackend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,19 @@ def test_make_runmeta_and_point_fn(simple_model):
120120
assert not vars["vector_interval__"].is_deterministic
121121
assert vars["matrix"].is_deterministic
122122
assert len(rmeta.sample_stats) == len(step.stats_dtypes[0])
123+
124+
with simple_model:
125+
step = pm.NUTS()
126+
rmeta, point_fn = make_runmeta_and_point_fn(
127+
initial_point=simple_model.initial_point(),
128+
step=step,
129+
model=simple_model,
130+
)
131+
assert isinstance(rmeta, mcb.RunMeta)
132+
svars = {s.name: s for s in rmeta.sample_stats}
133+
# Unbeknownst to McBackend, object stats are pickled to str
134+
assert "sampler_0__warning" in svars
135+
assert svars["sampler_0__warning"].dtype == "str"
123136
pass
124137

125138

0 commit comments

Comments
 (0)