File tree 2 files changed +18
-1
lines changed
2 files changed +18
-1
lines changed Original file line number Diff line number Diff line change @@ -221,9 +221,13 @@ def make_runmeta_and_point_fn(
221
221
(- 1 if s is None else s )
222
222
for s in (shape or [])
223
223
]
224
+ dt = np .dtype (dtype ).name
225
+ # Object types will be pickled by the ChainRecordAdapter!
226
+ if dt == "object" :
227
+ dt = "str"
224
228
svar = mcb .Variable (
225
229
name = sname ,
226
- dtype = np . dtype ( dtype ). name ,
230
+ dtype = dt ,
227
231
shape = sshape ,
228
232
undefined_ndim = shape is None ,
229
233
)
Original file line number Diff line number Diff line change @@ -120,6 +120,19 @@ def test_make_runmeta_and_point_fn(simple_model):
120
120
assert not vars ["vector_interval__" ].is_deterministic
121
121
assert vars ["matrix" ].is_deterministic
122
122
assert len (rmeta .sample_stats ) == len (step .stats_dtypes [0 ])
123
+
124
+ with simple_model :
125
+ step = pm .NUTS ()
126
+ rmeta , point_fn = make_runmeta_and_point_fn (
127
+ initial_point = simple_model .initial_point (),
128
+ step = step ,
129
+ model = simple_model ,
130
+ )
131
+ assert isinstance (rmeta , mcb .RunMeta )
132
+ svars = {s .name : s for s in rmeta .sample_stats }
133
+ # Unbeknownst to McBackend, object stats are pickled to str
134
+ assert "sampler_0__warning" in svars
135
+ assert svars ["sampler_0__warning" ].dtype == "str"
123
136
pass
124
137
125
138
You can’t perform that action at this time.
0 commit comments