Skip to content

Commit 8560f1e

Browse files
authored
Honor discard_tuned_samples during KeyboardInterrupt (#3785)
* Honor discard_tuned_samples during KeyboardInterrupt * Do not compute convergence checks without samples
1 parent 747db63 commit 8560f1e

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

pymc3/backends/report.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,14 @@ def raise_ok(self, level='error'):
9999
if errors:
100100
raise ValueError('Serious convergence issues during sampling.')
101101

102-
def _run_convergence_checks(self, idata:arviz.InferenceData, model):
102+
def _run_convergence_checks(self, idata: arviz.InferenceData, model):
103+
if not hasattr(idata, 'posterior'):
104+
msg = "No posterior samples. Unable to run convergence checks"
105+
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
106+
None, None, None)
107+
self._add_warnings([warn])
108+
return
109+
103110
if idata.posterior.sizes['chain'] == 1:
104111
msg = ("Only one chain was sampled, this makes it impossible to "
105112
"run some convergence checks")

pymc3/sampling.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ def sample(
502502
"random_seed": random_seed,
503503
"cores": cores,
504504
"callback": callback,
505+
"discard_tuned_samples": discard_tuned_samples,
505506
}
506507

507508
sample_args.update(kwargs)
@@ -1347,6 +1348,7 @@ def _mp_sample(
13471348
trace=None,
13481349
model=None,
13491350
callback=None,
1351+
discard_tuned_samples=True,
13501352
**kwargs
13511353
):
13521354
"""Main iteration for multiprocess sampling.
@@ -1439,7 +1441,10 @@ def _mp_sample(
14391441
raise
14401442
return MultiTrace(traces)
14411443
except KeyboardInterrupt:
1442-
traces, length = _choose_chains(traces, tune)
1444+
if discard_tuned_samples:
1445+
traces, length = _choose_chains(traces, tune)
1446+
else:
1447+
traces, length = _choose_chains(traces, 0)
14431448
return MultiTrace(traces)[:length]
14441449
finally:
14451450
for trace in traces:

0 commit comments

Comments
 (0)