Skip to content

Commit dcee437

Browse files
committed
Add more info to divergence warnings
1 parent 7842072 commit dcee437

File tree

4 files changed

+57
-33
lines changed

4 files changed

+57
-33
lines changed

pymc3/backends/report.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections import namedtuple
1615
import logging
1716
import enum
18-
import typing
17+
from typing import Any, Optional
18+
import dataclasses
19+
1920
from ..util import is_transformed_name, get_untransformed_name
2021

2122
import arviz
@@ -38,9 +39,17 @@ class WarningType(enum.Enum):
3839
BAD_ENERGY = 8
3940

4041

41-
SamplerWarning = namedtuple(
42-
'SamplerWarning',
43-
"kind, message, level, step, exec_info, extra")
42+
@dataclasses.dataclass
43+
class SamplerWarning:
44+
kind: WarningType
45+
message: str
46+
level: str
47+
step: Optional[int] = None
48+
exec_info: Optional[Any] = None
49+
extra: Optional[Any] = None
50+
divergence_point_source: Optional[dict] = None
51+
divergence_point_dest: Optional[dict] = None
52+
divergence_info: Optional[Any] = None
4453

4554

4655
_LEVELS = {
@@ -53,7 +62,8 @@ class WarningType(enum.Enum):
5362

5463

5564
class SamplerReport:
56-
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
65+
"""Bundle warnings, convergence stats and metadata of a sampling run."""
66+
5767
def __init__(self):
5868
self._chain_warnings = {}
5969
self._global_warnings = []
@@ -75,17 +85,17 @@ def ok(self):
7585
for warn in self._warnings)
7686

7787
@property
78-
def n_tune(self) -> typing.Optional[int]:
88+
def n_tune(self) -> Optional[int]:
7989
"""Number of tune iterations - not necessarily kept in trace!"""
8090
return self._n_tune
8191

8292
@property
83-
def n_draws(self) -> typing.Optional[int]:
93+
def n_draws(self) -> Optional[int]:
8494
"""Number of draw iterations."""
8595
return self._n_draws
8696

8797
@property
88-
def t_sampling(self) -> typing.Optional[float]:
98+
def t_sampling(self) -> Optional[float]:
8999
"""
90100
Number of seconds that the sampling procedure took.
91101
@@ -110,8 +120,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
110120
if idata.posterior.sizes['chain'] == 1:
111121
msg = ("Only one chain was sampled, this makes it impossible to "
112122
"run some convergence checks")
113-
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
114-
None, None, None)
123+
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
115124
self._add_warnings([warn])
116125
return
117126

@@ -134,41 +143,42 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
134143
msg = ("The rhat statistic is larger than 1.4 for some "
135144
"parameters. The sampler did not converge.")
136145
warn = SamplerWarning(
137-
WarningType.CONVERGENCE, msg, 'error', None, None, rhat)
146+
WarningType.CONVERGENCE, msg, 'error', extra=rhat)
138147
warnings.append(warn)
139148
elif rhat_max > 1.2:
140149
msg = ("The rhat statistic is larger than 1.2 for some "
141150
"parameters.")
142151
warn = SamplerWarning(
143-
WarningType.CONVERGENCE, msg, 'warn', None, None, rhat)
152+
WarningType.CONVERGENCE, msg, 'warn', extra=rhat)
144153
warnings.append(warn)
145154
elif rhat_max > 1.05:
146155
msg = ("The rhat statistic is larger than 1.05 for some "
147156
"parameters. This indicates slight problems during "
148157
"sampling.")
149158
warn = SamplerWarning(
150-
WarningType.CONVERGENCE, msg, 'info', None, None, rhat)
159+
WarningType.CONVERGENCE, msg, 'info', extra=rhat)
151160
warnings.append(warn)
152161

153162
eff_min = min(val.min() for val in ess.values())
154-
n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
163+
sizes = idata.posterior.sizes
164+
n_samples = sizes['chain'] * sizes['draw']
155165
if eff_min < 200 and n_samples >= 500:
156166
msg = ("The estimated number of effective samples is smaller than "
157167
"200 for some parameters.")
158168
warn = SamplerWarning(
159-
WarningType.CONVERGENCE, msg, 'error', None, None, ess)
169+
WarningType.CONVERGENCE, msg, 'error', extra=ess)
160170
warnings.append(warn)
161171
elif eff_min / n_samples < 0.1:
162172
msg = ("The number of effective samples is smaller than "
163173
"10% for some parameters.")
164174
warn = SamplerWarning(
165-
WarningType.CONVERGENCE, msg, 'warn', None, None, ess)
175+
WarningType.CONVERGENCE, msg, 'warn', extra=ess)
166176
warnings.append(warn)
167177
elif eff_min / n_samples < 0.25:
168178
msg = ("The number of effective samples is smaller than "
169179
"25% for some parameters.")
170180
warn = SamplerWarning(
171-
WarningType.CONVERGENCE, msg, 'info', None, None, ess)
181+
WarningType.CONVERGENCE, msg, 'info', extra=ess)
172182
warnings.append(warn)
173183

174184
self._add_warnings(warnings)
@@ -201,7 +211,7 @@ def filter_warns(warnings):
201211
filtered.append(warn)
202212
elif (start <= warn.step < stop and
203213
(warn.step - start) % step == 0):
204-
warn = warn._replace(step=warn.step - start)
214+
warn = dataclasses.replace(warn, step=warn.step - start)
205215
filtered.append(warn)
206216
return filtered
207217

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@
2929

3030
logger = logging.getLogger("pymc3")
3131

32-
HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
32+
HMCStepData = namedtuple(
33+
"HMCStepData",
34+
"end, accept_stat, divergence_info, stats"
35+
)
3336

37+
DivergenceInfo = namedtuple(
38+
"DivergenceInfo",
39+
"message, exec_info, state, state_div"
40+
)
3441

35-
DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state")
3642

3743
class BaseHMC(arraystep.GradientSharedStep):
3844
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
@@ -155,8 +161,6 @@ def astep(self, q0):
155161
message_energy,
156162
"critical",
157163
self.iter_count,
158-
None,
159-
None,
160164
)
161165
self._warnings.append(warning)
162166
raise SamplingError("Bad initial energy")
@@ -177,19 +181,30 @@ def astep(self, q0):
177181
self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
178182
if hmc_step.divergence_info:
179183
info = hmc_step.divergence_info
184+
point = None
185+
point_dest = None
186+
info_store = None
180187
if self.tune:
181188
kind = WarningType.TUNING_DIVERGENCE
182-
point = None
183189
else:
184190
kind = WarningType.DIVERGENCE
185191
self._num_divs_sample += 1
186192
# We don't want to fill up all memory with divergence info
187193
if self._num_divs_sample < 100:
188194
point = self._logp_dlogp_func.array_to_dict(info.state.q)
189-
else:
190-
point = None
195+
point_dest = self._logp_dlogp_func.array_to_dict(
196+
info.state_div.q
197+
)
198+
info_store = info
191199
warning = SamplerWarning(
192-
kind, info.message, "debug", self.iter_count, info.exec_info, point
200+
kind,
201+
info.message,
202+
"debug",
203+
self.iter_count,
204+
info.exec_info,
205+
divergence_point_source=point,
206+
divergence_point_dest=point_dest,
207+
divergence_info=info_store,
193208
)
194209

195210
self._warnings.append(warning)
@@ -243,9 +258,7 @@ def warnings(self):
243258
)
244259

245260
if message:
246-
warning = SamplerWarning(
247-
WarningType.DIVERGENCES, message, "error", None, None, None
248-
)
261+
warning = SamplerWarning(WarningType.DIVERGENCES, message, "error")
249262
warnings.append(warning)
250263

251264
warnings.extend(self.step_adapt.warnings())

pymc3/step_methods/hmc/nuts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def warnings(self):
210210
"The chain reached the maximum tree depth. Increase "
211211
"max_treedepth, increase target_accept or reparameterize."
212212
)
213-
warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
213+
warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn')
214214
warnings.append(warn)
215215
return warnings
216216

@@ -331,6 +331,7 @@ def _single_step(self, left, epsilon):
331331
except IntegrationError as err:
332332
error_msg = str(err)
333333
error = err
334+
right = None
334335
else:
335336
# h - H0
336337
energy_change = right.energy - self.start_energy
@@ -363,7 +364,7 @@ def _single_step(self, left, epsilon):
363364
)
364365
error = None
365366
tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
366-
divergance_info = DivergenceInfo(error_msg, error, left)
367+
divergance_info = DivergenceInfo(error_msg, error, left, right)
367368
return tree, divergance_info, False
368369

369370
def _build_subtree(self, left, depth, epsilon):

pymc3/step_methods/step_sizes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def warnings(self):
7777
% (mean_accept, target_accept))
7878
info = {'target': target_accept, 'actual': mean_accept}
7979
warning = SamplerWarning(
80-
WarningType.BAD_ACCEPTANCE, msg, 'warn', None, None, info)
80+
WarningType.BAD_ACCEPTANCE, msg, 'warn', extra=info)
8181
return [warning]
8282
else:
8383
return []

0 commit comments

Comments (0)