Add more info to divergence warnings #3990

Merged · 4 commits · Jul 5, 2020

48 changes: 29 additions & 19 deletions pymc3/backends/report.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 import logging
 import enum
-import typing
+from typing import Any, Optional
+import dataclasses
 
 from ..util import is_transformed_name, get_untransformed_name
 
 import arviz
@@ -38,9 +39,17 @@ class WarningType(enum.Enum):
     BAD_ENERGY = 8
 
 
-SamplerWarning = namedtuple(
-    'SamplerWarning',
-    "kind, message, level, step, exec_info, extra")
+@dataclasses.dataclass
+class SamplerWarning:
+    kind: WarningType
+    message: str
+    level: str
+    step: Optional[int] = None
+    exec_info: Optional[Any] = None
+    extra: Optional[Any] = None
+    divergence_point_source: Optional[dict] = None
+    divergence_point_dest: Optional[dict] = None
+    divergence_info: Optional[Any] = None


 _LEVELS = {
@@ -53,7 +62,8 @@ class WarningType(enum.Enum):
 
 
 class SamplerReport:
-    """This object bundles warnings, convergence statistics and metadata of a sampling run."""
+    """Bundle warnings, convergence stats and metadata of a sampling run."""
+
     def __init__(self):
         self._chain_warnings = {}
         self._global_warnings = []
@@ -75,17 +85,17 @@ def ok(self):
                    for warn in self._warnings)
 
     @property
-    def n_tune(self) -> typing.Optional[int]:
+    def n_tune(self) -> Optional[int]:
         """Number of tune iterations - not necessarily kept in trace!"""
         return self._n_tune
 
     @property
-    def n_draws(self) -> typing.Optional[int]:
+    def n_draws(self) -> Optional[int]:
         """Number of draw iterations."""
         return self._n_draws
 
     @property
-    def t_sampling(self) -> typing.Optional[float]:
+    def t_sampling(self) -> Optional[float]:
         """
         Number of seconds that the sampling procedure took.
 
@@ -110,8 +120,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
         if idata.posterior.sizes['chain'] == 1:
             msg = ("Only one chain was sampled, this makes it impossible to "
                    "run some convergence checks")
-            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
-                                  None, None, None)
+            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
             self._add_warnings([warn])
             return
 
@@ -134,41 +143,42 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
             msg = ("The rhat statistic is larger than 1.4 for some "
                    "parameters. The sampler did not converge.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'error', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'error', extra=rhat)
             warnings.append(warn)
         elif rhat_max > 1.2:
             msg = ("The rhat statistic is larger than 1.2 for some "
                    "parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'warn', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'warn', extra=rhat)
             warnings.append(warn)
         elif rhat_max > 1.05:
             msg = ("The rhat statistic is larger than 1.05 for some "
                    "parameters. This indicates slight problems during "
                    "sampling.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'info', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'info', extra=rhat)
             warnings.append(warn)
 
         eff_min = min(val.min() for val in ess.values())
-        n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
+        sizes = idata.posterior.sizes
+        n_samples = sizes['chain'] * sizes['draw']
         if eff_min < 200 and n_samples >= 500:
             msg = ("The estimated number of effective samples is smaller than "
                    "200 for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'error', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'error', extra=ess)
             warnings.append(warn)
         elif eff_min / n_samples < 0.1:
             msg = ("The number of effective samples is smaller than "
                    "10% for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'warn', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'warn', extra=ess)
             warnings.append(warn)
         elif eff_min / n_samples < 0.25:
             msg = ("The number of effective samples is smaller than "
                    "25% for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'info', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'info', extra=ess)
             warnings.append(warn)
 
         self._add_warnings(warnings)
@@ -201,7 +211,7 @@ def filter_warns(warnings):
                     filtered.append(warn)
                 elif (start <= warn.step < stop and
                         (warn.step - start) % step == 0):
-                    warn = warn._replace(step=warn.step - start)
+                    warn = dataclasses.replace(warn, step=warn.step - start)
                     filtered.append(warn)
             return filtered
 
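Dataclasses have no `_replace` method, so the step renumbering above switches to `dataclasses.replace`, which returns a copy with the named fields overridden, e.g.:

```python
import dataclasses

from pymc3.backends.report import SamplerWarning, WarningType

w = SamplerWarning(WarningType.DIVERGENCE, "Divergence encountered.", "debug", step=57)
w2 = dataclasses.replace(w, step=w.step - 50)
assert w2.step == 7 and w2.message == w.message  # other fields copied unchanged
```
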
42 changes: 29 additions & 13 deletions pymc3/step_methods/hmc/base_hmc.py
@@ -29,10 +29,16 @@
 
 logger = logging.getLogger("pymc3")
 
-HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
+HMCStepData = namedtuple(
+    "HMCStepData",
+    "end, accept_stat, divergence_info, stats"
+)
 
-DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state")
+DivergenceInfo = namedtuple(
+    "DivergenceInfo",
+    "message, exec_info, state, state_div"
+)
 
 
 class BaseHMC(arraystep.GradientSharedStep):
     """Superclass to implement Hamiltonian/hybrid monte carlo."""
@@ -148,15 +154,14 @@ def astep(self, q0):
             self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
             message_energy = (
                 "Bad initial energy, check any log probabilities that "
-                "are inf or -inf, nan or very small:\n{}".format(error_logp.to_string())
+                "are inf or -inf, nan or very small:\n{}"
+                .format(error_logp.to_string())
             )
             warning = SamplerWarning(
                 WarningType.BAD_ENERGY,
                 message_energy,
                 "critical",
                 self.iter_count,
-                None,
-                None,
             )
             self._warnings.append(warning)
             raise SamplingError("Bad initial energy")
@@ -177,19 +182,32 @@ def astep(self, q0):
         self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
         if hmc_step.divergence_info:
             info = hmc_step.divergence_info
+            point = None
+            point_dest = None
+            info_store = None
             if self.tune:
                 kind = WarningType.TUNING_DIVERGENCE
-                point = None
             else:
                 kind = WarningType.DIVERGENCE
                 self._num_divs_sample += 1
                 # We don't want to fill up all memory with divergence info
-                if self._num_divs_sample < 100:
+                if self._num_divs_sample < 100 and info.state is not None:
                     point = self._logp_dlogp_func.array_to_dict(info.state.q)
-                else:
-                    point = None
+                if self._num_divs_sample < 100 and info.state_div is not None:
+                    point_dest = self._logp_dlogp_func.array_to_dict(
+                        info.state_div.q
+                    )
+                if self._num_divs_sample < 100:
+                    info_store = info
             warning = SamplerWarning(
-                kind, info.message, "debug", self.iter_count, info.exec_info, point
+                kind,
+                info.message,
+                "debug",
+                self.iter_count,
+                info.exec_info,
+                divergence_point_source=point,
+                divergence_point_dest=point_dest,
+                divergence_info=info_store,
             )
 
             self._warnings.append(warning)
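
After sampling, these warnings end up on the trace report, so the start and end points of each divergence can be inspected. A sketch of consumer code, assuming some divergences actually occurred; note that `_chain_warnings` is a private attribute in this version and is used here only for illustration:

```python
import pymc3 as pm

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)
    trace = pm.sample()

# Print source/destination points of any recorded divergences, per chain.
for chain, warns in trace.report._chain_warnings.items():
    for warn in warns:
        if warn.divergence_point_source is not None:
            print(chain, warn.divergence_point_source, warn.divergence_point_dest)
```
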
@@ -243,9 +261,7 @@ def warnings(self):
             )
 
         if message:
-            warning = SamplerWarning(
-                WarningType.DIVERGENCES, message, "error", None, None, None
-            )
+            warning = SamplerWarning(WarningType.DIVERGENCES, message, "error")
             warnings.append(warning)
 
         warnings.extend(self.step_adapt.warnings())
8 changes: 5 additions & 3 deletions pymc3/step_methods/hmc/hmc.py
@@ -116,23 +116,25 @@ def _hamiltonian_step(self, start, p0, step_size):
 
         energy_change = -np.inf
         state = start
+        last = state
         div_info = None
         try:
             for _ in range(n_steps):
+                last = state
                 state = self.integrator.step(step_size, state)
         except IntegrationError as e:
-            div_info = DivergenceInfo('Divergence encountered.', e, state)
+            div_info = DivergenceInfo('Integration failed.', e, last, None)
         else:
             if not np.isfinite(state.energy):
                 div_info = DivergenceInfo(
-                    'Divergence encountered, bad energy.', None, state)
+                    'Divergence encountered, bad energy.', None, last, state)
             energy_change = start.energy - state.energy
             if np.isnan(energy_change):
                 energy_change = -np.inf
             if np.abs(energy_change) > self.Emax:
                 div_info = DivergenceInfo(
                     'Divergence encountered, large integration error.',
-                    None, state)
+                    None, last, state)
 
         accept_stat = min(1, np.exp(energy_change))
 
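The new `last` variable trails the integrator by one step, so on a divergence the warning can report the last finite point as the source and the diverging point as the destination, instead of only the diverged state whose fields may already be `inf`/`nan`. The bookkeeping in isolation, with toy stand-ins (not pymc3's real integrator):

```python
# Toy stand-ins, just to show the pattern; names and thresholds are invented.
class IntegrationError(RuntimeError):
    pass

def step(eps, q):
    q2 = q * 3.0
    if q2 > 100.0:
        raise IntegrationError("diverged")
    return q2

state = last = 1.0
div_source = div_dest = None
try:
    for _ in range(10):
        last = state          # trails the integrator by one step
        state = step(0.1, state)
except IntegrationError:
    div_source, div_dest = last, None  # failed step produced no usable state
```
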
5 changes: 3 additions & 2 deletions pymc3/step_methods/hmc/nuts.py
@@ -210,7 +210,7 @@ def warnings(self):
                 "The chain reached the maximum tree depth. Increase "
                 "max_treedepth, increase target_accept or reparameterize."
             )
-            warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
+            warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn')
             warnings.append(warn)
         return warnings
 
@@ -331,6 +331,7 @@ def _single_step(self, left, epsilon):
         except IntegrationError as err:
             error_msg = str(err)
             error = err
+            right = None
         else:
             # h - H0
             energy_change = right.energy - self.start_energy
@@ -363,7 +364,7 @@ def _single_step(self, left, epsilon):
             )
             error = None
         tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
-        divergance_info = DivergenceInfo(error_msg, error, left)
+        divergance_info = DivergenceInfo(error_msg, error, left, right)
         return tree, divergance_info, False
 
     def _build_subtree(self, left, depth, epsilon):
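
Initializing `right = None` in the `except` branch above matters because `DivergenceInfo(error_msg, error, left, right)` is built for both failure modes: after an `IntegrationError` there is no destination state, while an energy-based divergence still has a usable `right`. This is also why `base_hmc.py` above guards with `info.state_div is not None` before converting the state to a point dict.
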
2 changes: 1 addition & 1 deletion pymc3/step_methods/step_sizes.py
@@ -77,7 +77,7 @@ def warnings(self):
                    % (mean_accept, target_accept))
             info = {'target': target_accept, 'actual': mean_accept}
             warning = SamplerWarning(
-                WarningType.BAD_ACCEPTANCE, msg, 'warn', None, None, info)
+                WarningType.BAD_ACCEPTANCE, msg, 'warn', extra=info)
             return [warning]
         else:
             return []
1 change: 0 additions & 1 deletion requirements.txt
@@ -7,4 +7,3 @@ patsy>=0.5.1
 fastprogress>=0.2.0
 h5py>=2.7.0
 typing-extensions>=3.7.4
-contextvars; python_version < '3.7'
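
Dropping the `contextvars` backport is consistent with the rest of the diff: the new code relies on `dataclasses`, and both `dataclasses` and `contextvars` are in the standard library from Python 3.7 on, so the `python_version < '3.7'` pin is presumably obsolete.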