3
3
4
4
The only PyMC dependency is on the ``BaseTrace`` abstract base class.
5
5
"""
6
+ import base64
7
+ import pickle
6
8
from typing import Dict , List , Optional , Sequence , Tuple
7
9
8
10
import hagelkorn
@@ -159,12 +161,16 @@ def setup(
159
161
self ._stat_groups .append ([])
160
162
for statname , dtype in names_dtypes .items ():
161
163
sname = f"sampler_{ s } __{ statname } "
162
- svar = Variable (
163
- name = sname ,
164
- dtype = numpy .dtype (dtype ).name ,
165
- # This 👇 is needed until PyMC provides shapes ahead of time.
166
- undefined_ndim = True ,
167
- )
164
+ if statname == "warning" :
165
+ # SamplerWarnings will be pickled and stored as string!
166
+ svar = Variable (sname , "str" )
167
+ else :
168
+ svar = Variable (
169
+ name = sname ,
170
+ dtype = numpy .dtype (dtype ).name ,
171
+ # This 👇 is needed until PyMC provides shapes ahead of time.
172
+ undefined_ndim = True ,
173
+ )
168
174
self ._stat_groups [s ].append ((sname , statname ))
169
175
sample_stats .append (svar )
170
176
@@ -197,8 +203,12 @@ def record(self, point, sampler_states=None):
197
203
for s , sts in enumerate (sampler_states ):
198
204
for statname , sval in sts .items ():
199
205
sname = f"sampler_{ s } __{ statname } "
200
- stats [sname ] = sval
201
- # Make not whether this is a tuning iteration.
206
+ # Automatically pickle SamplerWarnings
207
+ if statname == "warning" :
208
+ sval_bytes = pickle .dumps (sval )
209
+ sval = base64 .encodebytes (sval_bytes ).decode ("ascii" )
210
+ stats [sname ] = numpy .asarray (sval )
211
+ # Make note whether this is a tuning iteration.
202
212
if statname == "tune" :
203
213
stats ["tune" ] = sval
204
214
@@ -214,7 +224,16 @@ def get_values(self, varname, burn=0, thin=1) -> numpy.ndarray:
214
224
def _get_stats (self , varname , burn = 0 , thin = 1 ) -> numpy .ndarray :
215
225
if self ._chain is None :
216
226
raise Exception ("Trace setup was not completed. Call `.setup()` first." )
217
- return self ._chain .get_stats (varname )[burn ::thin ]
227
+ values = self ._chain .get_stats (varname )[burn ::thin ]
228
+ if "warning" in varname :
229
+ objs = []
230
+ for v in values :
231
+ enc = v .encode ("ascii" )
232
+ str_ = base64 .decodebytes (enc )
233
+ obj = pickle .loads (str_ )
234
+ objs .append (obj )
235
+ values = numpy .array (objs , dtype = object )
236
+ return values
218
237
219
238
def _get_sampler_stats (self , stat_name , sampler_idx , burn , thin ):
220
239
if self ._chain is None :
0 commit comments