9
9
10
10
from .meta import ChainMeta , RunMeta , Variable
11
11
from .npproto .utils import ndarray_to_numpy
12
+ from .utils import as_array_from_ragged
12
13
13
14
InferenceData = TypeVar ("InferenceData" )
14
15
try :
@@ -252,7 +253,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
252
253
warmup_sample_stats [svar .name ].append (stats [tune ])
253
254
sample_stats [svar .name ].append (stats [~ tune ])
254
255
255
- kwargs .setdefault ("save_warmup" , True )
256
+ if not equalize_chain_lengths :
257
+ # Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
258
+ warmup_posterior = {k : as_array_from_ragged (v ) for k , v in warmup_posterior .items ()}
259
+ warmup_sample_stats = {
260
+ k : as_array_from_ragged (v ) for k , v in warmup_sample_stats .items ()
261
+ }
262
+ posterior = {k : as_array_from_ragged (v ) for k , v in posterior .items ()}
263
+ sample_stats = {k : as_array_from_ragged (v ) for k , v in sample_stats .items ()}
264
+
256
265
idata = from_dict (
257
266
warmup_posterior = warmup_posterior ,
258
267
warmup_sample_stats = warmup_sample_stats ,
@@ -263,6 +272,7 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
263
272
attrs = self .meta .attributes ,
264
273
constant_data = self .constant_data ,
265
274
observed_data = self .observed_data ,
275
+ save_warmup = True ,
266
276
** kwargs ,
267
277
)
268
278
return idata
0 commit comments