@@ -152,11 +152,13 @@ def constant_data(self) -> Dict[str, numpy.ndarray]:
152
152
def observed_data (self ) -> Dict [str , numpy .ndarray ]:
153
153
return {dv .name : ndarray_to_numpy (dv .value ) for dv in self .meta .data if dv .is_observed }
154
154
155
- def to_inferencedata (self , ** kwargs ) -> InferenceData :
155
+ def to_inferencedata (self , * , equalize_chain_lengths : bool = True , * *kwargs ) -> InferenceData :
156
156
"""Creates an ArviZ ``InferenceData`` object from this run.
157
157
158
158
Parameters
159
159
----------
160
+ equalize_chain_lengths : bool
161
+ Whether to truncate all chains to the shortest chain length (default: ``True``).
160
162
**kwargs
161
163
Will be forwarded to ``arviz.from_dict()``.
162
164
@@ -181,16 +183,20 @@ def to_inferencedata(self, **kwargs) -> InferenceData:
181
183
chain_lengths = {c .cid : len (c ) for c in chains }
182
184
if len (set (chain_lengths .values ())) != 1 :
183
185
_log .warning ("Chains vary in length. Lenghts are: %s" , chain_lengths )
184
-
186
+ clen = None
187
+ if equalize_chain_lengths :
188
+ # A minimum chain length is introduced so that all chains have equal length
189
+ clen = min (chain_lengths .values ())
185
190
# Aggregate draws and stats, while splitting into warmup/posterior
186
191
warmup_posterior = collections .defaultdict (list )
187
192
warmup_sample_stats = collections .defaultdict (list )
188
193
posterior = collections .defaultdict (list )
189
194
sample_stats = collections .defaultdict (list )
190
195
for c , chain in enumerate (chains ):
191
- # Every retrieved array is shortened to the previously determined chain length.
192
- # This is needed for database backends which may get inserts inbetween.
193
- clen = chain_lengths [chain .cid ]
196
+ if clen is None :
197
+ # Every retrieved array is shortened to the previously determined chain length.
198
+ # This is needed for database backends which may get inserts inbetween.
199
+ clen = chain_lengths [chain .cid ]
194
200
195
201
# Obtain a mask by which draws can be split into warmup/posterior
196
202
if "tune" in chain .sample_stats :
0 commit comments