File tree 2 files changed +14
-2
lines changed 2 files changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -99,7 +99,14 @@ def raise_ok(self, level='error'):
99
99
if errors :
100
100
raise ValueError ('Serious convergence issues during sampling.' )
101
101
102
- def _run_convergence_checks (self , idata :arviz .InferenceData , model ):
102
+ def _run_convergence_checks (self , idata : arviz .InferenceData , model ):
103
+ if not hasattr (idata , 'posterior' ):
104
+ msg = "No posterior samples. Unable to run convergence checks"
105
+ warn = SamplerWarning (WarningType .BAD_PARAMS , msg , 'info' ,
106
+ None , None , None )
107
+ self ._add_warnings ([warn ])
108
+ return
109
+
103
110
if idata .posterior .sizes ['chain' ] == 1 :
104
111
msg = ("Only one chain was sampled, this makes it impossible to "
105
112
"run some convergence checks" )
Original file line number Diff line number Diff line change @@ -502,6 +502,7 @@ def sample(
502
502
"random_seed" : random_seed ,
503
503
"cores" : cores ,
504
504
"callback" : callback ,
505
+ "discard_tuned_samples" : discard_tuned_samples ,
505
506
}
506
507
507
508
sample_args .update (kwargs )
@@ -1347,6 +1348,7 @@ def _mp_sample(
1347
1348
trace = None ,
1348
1349
model = None ,
1349
1350
callback = None ,
1351
+ discard_tuned_samples = True ,
1350
1352
** kwargs
1351
1353
):
1352
1354
"""Main iteration for multiprocess sampling.
@@ -1439,7 +1441,10 @@ def _mp_sample(
1439
1441
raise
1440
1442
return MultiTrace (traces )
1441
1443
except KeyboardInterrupt :
1442
- traces , length = _choose_chains (traces , tune )
1444
+ if discard_tuned_samples :
1445
+ traces , length = _choose_chains (traces , tune )
1446
+ else :
1447
+ traces , length = _choose_chains (traces , 0 )
1443
1448
return MultiTrace (traces )[:length ]
1444
1449
finally :
1445
1450
for trace in traces :
You can’t perform that action at this time.
0 commit comments