Collect sampler warnings only through stats #6192

Merged · 3 commits · Oct 28, 2022

1 change: 1 addition & 0 deletions pymc/__init__.py
@@ -72,6 +72,7 @@ def __set_compiler_flags():
 from pymc.stats import *
 from pymc.step_methods import *
 from pymc.tuning import *
+from pymc.util import drop_warning_stat
 from pymc.variational import *
 from pymc.vartypes import *

7 changes: 0 additions & 7 deletions pymc/backends/base.py
@@ -78,10 +78,6 @@ def __init__(self, name, model=None, vars=None, test_point=None):
         self.chain = None
         self._is_base_setup = False
         self.sampler_vars = None
-        self._warnings = []
-
-    def _add_warnings(self, warnings):
-        self._warnings.extend(warnings)

     # Sampling methods

@@ -288,9 +284,6 @@ def __init__(self, straces):
             self._straces[strace.chain] = strace

         self._report = SamplerReport()
-        for strace in straces:
-            if hasattr(strace, "_warnings"):
-                self._report._add_warnings(strace._warnings, strace.chain)

     def __repr__(self):
         template = "<{}: {} chains, {} iterations, {} variables>"

52 changes: 21 additions & 31 deletions pymc/parallel_sampling.py
@@ -40,12 +40,9 @@


 class ParallelSamplingError(Exception):
-    def __init__(self, message, chain, warnings=None):
+    def __init__(self, message, chain):
         super().__init__(message)
-        if warnings is None:
-            warnings = []
         self._chain = chain
-        self._warnings = warnings


 # Taken from https://hg.python.org/cpython/rev/c4f92b597074
@@ -74,8 +71,8 @@ def rebuild_exc(exc, tb):


 # Messages
-# ('writing_done', is_last, sample_idx, tuning, stats, warns)
-# ('error', warnings, *exception_info)
+# ('writing_done', is_last, sample_idx, tuning, stats)
+# ('error', *exception_info)

 # ('abort', reason)
 # ('write_next',)
@@ -133,7 +130,7 @@ def run(self):
             e = ExceptionWithTraceback(e, e.__traceback__)
             # Send is not blocking so we have to force a wait for the abort
             # message
-            self._msg_pipe.send(("error", None, e))
+            self._msg_pipe.send(("error", e))
             self._wait_for_abortion()
         finally:
             self._msg_pipe.close()
@@ -181,9 +178,8 @@ def _start_loop(self):
                 try:
                     point, stats = self._compute_point()
                 except SamplingError as e:
-                    warns = self._collect_warnings()
                     e = ExceptionWithTraceback(e, e.__traceback__)
-                    self._msg_pipe.send(("error", warns, e))
+                    self._msg_pipe.send(("error", e))
             else:
                 return

@@ -193,11 +189,7 @@ def _start_loop(self):
             elif msg[0] == "write_next":
                 self._write_point(point)
                 is_last = draw + 1 == self._draws + self._tune
-                if is_last:
-                    warns = self._collect_warnings()
-                else:
-                    warns = None
-                self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats, warns))
+                self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
                 draw += 1
             else:
                 raise ValueError("Unknown message " + msg[0])
@@ -210,12 +202,6 @@ def _compute_point(self):
             stats = None
         return point, stats

-    def _collect_warnings(self):
-        if hasattr(self._step_method, "warnings"):
-            return self._step_method.warnings()
-        else:
-            return []
-

 def _run_process(*args):
     _Process(*args).run()
@@ -308,11 +294,13 @@ def _send(self, msg, *args):
        except Exception:
            pass
        if message is not None and message[0] == "error":
-            warns, old_error = message[1:]
-            if warns is not None:
-                error = ParallelSamplingError(str(old_error), self.chain, warns)
+            old_error = message[1]
+            if old_error is not None:
+                error = ParallelSamplingError(
+                    f"Chain {self.chain} failed with: {old_error}", self.chain
+                )
            else:
-                error = RuntimeError("Chain %s failed." % self.chain)
+                error = RuntimeError(f"Chain {self.chain} failed.")
            raise error from old_error
        raise
@@ -345,11 +333,13 @@ def recv_draw(processes, timeout=3600):
         msg = ready[0].recv()

         if msg[0] == "error":
-            warns, old_error = msg[1:]
-            if warns is not None:
-                error = ParallelSamplingError(str(old_error), proc.chain, warns)
+            old_error = msg[1]
+            if old_error is not None:
+                error = ParallelSamplingError(
+                    f"Chain {proc.chain} failed with: {old_error}", proc.chain
+                )
             else:
-                error = RuntimeError("Chain %s failed." % proc.chain)
+                error = RuntimeError(f"Chain {proc.chain} failed.")
             raise error from old_error
         elif msg[0] == "writing_done":
             proc._readable = True
@@ -383,7 +373,7 @@ def terminate_all(processes, patience=2):
             process.join()


-Draw = namedtuple("Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point", "warnings"])
+Draw = namedtuple("Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point"])


 class ParallelSampler:
@@ -466,7 +456,7 @@ def __iter__(self):

         while self._active:
             draw = ProcessAdapter.recv_draw(self._active)
-            proc, is_last, draw, tuning, stats, warns = draw
+            proc, is_last, draw, tuning, stats = draw
             self._total_draws += 1
             if not tuning and stats and stats[0].get("diverging"):
                 self._divergences += 1
@@ -491,7 +481,7 @@ def __iter__(self):
             if not is_last:
                 proc.write_next()

-            yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)
+            yield Draw(proc.chain, is_last, draw, tuning, stats, point)

     def __enter__(self):
         self._in_context = True

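Note: with the `warnings` field removed, the worker-to-parent protocol carries warnings only inside `stats`. A hand-built sketch of the new `Draw` payload (all values below are illustrative, not produced by an actual sampler):

```python
from collections import namedtuple

# Mirrors the namedtuple defined above; warnings now ride inside `stats`.
Draw = namedtuple("Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point"])

draw = Draw(
    chain=0,
    is_last=False,
    draw_idx=17,
    tuning=True,
    stats=[{"diverging": False}],  # a "warning" entry may appear in these dicts
    point={"x": 0.1},
)
```
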
51 changes: 35 additions & 16 deletions pymc/sampling.py
@@ -70,12 +70,13 @@
 )
 from pymc.model import Model, modelcontext
 from pymc.parallel_sampling import Draw, _cpu_count
-from pymc.stats.convergence import run_convergence_checks
+from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks
 from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
 from pymc.util import (
     dataset_to_point_list,
+    drop_warning_stat,
     get_default_varnames,
     get_untransformed_name,
     is_transformed_name,
@@ -323,6 +324,7 @@ def sample(
     jitter_max_retries: int = 10,
     *,
     return_inferencedata: bool = True,
+    keep_warning_stat: bool = False,
     idata_kwargs: dict = None,
     mp_ctx=None,
     **kwargs,
@@ -393,6 +395,13 @@
         `MultiTrace` (False). Defaults to `True`.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`pymc.to_inference_data`
+    keep_warning_stat : bool
+        If ``True`` the "warning" stat emitted by, for example, HMC samplers will be kept
+        in the returned ``idata.sample_stats`` group.
+        This leads to the ``idata`` not supporting ``.to_netcdf()`` or ``.to_zarr()`` and
+        should only be set to ``True`` if you intend to use the "warning" objects right away.
+        Defaults to ``False`` such that ``pm.drop_warning_stat`` is applied automatically,
+        making the ``InferenceData`` compatible with saving.
     mp_ctx : multiprocessing.context.BaseContext
         A multiprocessing context for parallel sampling.
         See multiprocessing documentation for details.
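Note: a minimal sketch of how the new flag is meant to be used (the model and the filename are placeholders):

```python
import pymc as pm

with pm.Model():
    pm.Normal("x")
    idata = pm.sample(keep_warning_stat=True)

# The raw SamplerWarning objects stay available for immediate inspection ...
warning_stat = idata.sample_stats.get("warning")  # present when the sampler emitted any
# ... while dropping the stat afterwards restores serializability:
pm.drop_warning_stat(idata).to_netcdf("trace.nc")
```
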
@@ -699,6 +708,10 @@ def sample(
         mtrace.report._add_warnings(convergence_warnings)

     if return_inferencedata:
+        # By default we drop the "warning" stat which contains `SamplerWarning`
+        # objects that cannot be stored with `.to_netcdf()`.
+        if not keep_warning_stat:
+            return drop_warning_stat(idata)
         return idata
     return mtrace

@@ -1048,32 +1061,26 @@ def _iter_sample(
             if step.generates_stats:
                 point, stats = step.step(point)
                 strace.record(point, stats)
+                log_warning_stats(stats)
                 diverging = i > tune and stats and stats[0].get("diverging")
             else:
                 point = step.step(point)
                 strace.record(point)
             if callback is not None:
-                warns = getattr(step, "warnings", None)
                 callback(
                     trace=strace,
-                    draw=Draw(chain, i == draws, i, i < tune, stats, point, warns),
+                    draw=Draw(chain, i == draws, i, i < tune, stats, point),
                 )

             yield strace, diverging
     except KeyboardInterrupt:
         strace.close()
-        if hasattr(step, "warnings"):
-            warns = step.warnings()
-            strace._add_warnings(warns)
         raise
     except BaseException:
         strace.close()
         raise
     else:
         strace.close()
-        if hasattr(step, "warnings"):
-            warns = step.warnings()
-            strace._add_warnings(warns)


 class PopulationStepper:
@@ -1356,6 +1363,7 @@ def _iter_population(
                 if steppers[c].generates_stats:
                     points[c], stats = updates[c]
                     strace.record(points[c], stats)
+                    log_warning_stats(stats)
                 else:
                     points[c] = updates[c]
                     strace.record(points[c])
@@ -1513,21 +1521,16 @@ def _mp_sample(
         with sampler:
             for draw in sampler:
                 strace = traces[draw.chain]
-                if draw.stats is not None:
-                    strace.record(draw.point, draw.stats)
-                else:
-                    strace.record(draw.point)
+                strace.record(draw.point, draw.stats)
+                log_warning_stats(draw.stats)
                 if draw.is_last:
                     strace.close()
-                    if draw.warnings is not None:
-                        strace._add_warnings(draw.warnings)

                 if callback is not None:
                     callback(trace=trace, draw=draw)

     except ps.ParallelSamplingError as error:
         strace = traces[error._chain]
-        strace._add_warnings(error._warnings)
         for strace in traces:
             strace.close()
@@ -1546,6 +1549,22 @@ def _mp_sample(
             strace.close()


+def log_warning_stats(stats: Sequence[Dict[str, Any]]):
+    """Logs 'warning' stats if present."""
+    if stats is None:
+        return
+
+    for sts in stats:
+        warn = sts.get("warning", None)
+        if warn is None:
+            continue
+        if isinstance(warn, SamplerWarning):
+            log_warning(warn)
+        else:
+            _log.warning(warn)
+    return
+
+
 def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]:
     """
     Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.

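Note: a small sketch exercising the new helper with a hand-built stats list (real step methods such as NUTS attach `SamplerWarning` objects under the "warning" key):

```python
from pymc.sampling import log_warning_stats
from pymc.stats.convergence import SamplerWarning, WarningType

stats = [
    {"diverging": True,
     "warning": SamplerWarning(WarningType.DIVERGENCE, "Divergence after tuning.", "warn")},
    {"diverging": False},  # no "warning" key, so this entry is skipped
]
log_warning_stats(stats)  # routes the SamplerWarning through log_warning
log_warning_stats(None)   # no-op, mirroring the None guard above
```
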
53 changes: 51 additions & 2 deletions pymc/stats/convergence.py
@@ -68,7 +68,7 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWarning]:
         warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
         return [warn]

-    warnings = []
+    warnings: List[SamplerWarning] = []
     valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
     varnames = []
     for rv in model.free_RVs:
@@ -104,11 +104,60 @@
         warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
         warnings.append(warn)

+    warnings += warn_divergences(idata)
+    warnings += warn_treedepth(idata)
+
     return warnings


+def warn_divergences(idata: arviz.InferenceData) -> List[SamplerWarning]:
+    """Checks sampler stats and creates a list of warnings about divergences."""
+    sampler_stats = idata.get("sample_stats", None)
+    if sampler_stats is None:
+        return []
+
+    diverging = sampler_stats.get("diverging", None)
+    if diverging is None:
+        return []
+
+    # Warn about divergences
+    n_div = int(diverging.sum())
+    if n_div == 0:
+        return []
+    warning = SamplerWarning(
+        WarningType.DIVERGENCES,
+        f"There were {n_div} divergences after tuning. Increase `target_accept` or reparameterize.",
+        "error",
+    )
+    return [warning]
+
+
+def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]:
+    """Checks sampler stats and creates a list of warnings about tree depth."""
+    sampler_stats = idata.get("sample_stats", None)
+    if sampler_stats is None:
+        return []
+
+    treedepth = sampler_stats.get("tree_depth", None)
+    if treedepth is None:
+        return []
+
+    warnings = []
+    for c in treedepth.chain:
+        if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05:
+            warnings.append(
+                SamplerWarning(
+                    WarningType.TREEDEPTH,
+                    f"Chain {c} reached the maximum tree depth."
+                    " Increase `max_treedepth`, increase `target_accept` or reparameterize.",
+                    "warn",
+                )
+            )
+    return warnings
+
+
 def log_warning(warn: SamplerWarning):
-    level = _LEVELS[warn.level]
+    level = _LEVELS.get(warn.level, logging.WARNING)
     logger.log(level, warn.message)

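Note: because the new checks read only from the `sample_stats` group, they can be run against any saved trace. A sketch (the file path is a placeholder):

```python
import arviz as az

from pymc.stats.convergence import log_warning, warn_divergences, warn_treedepth

idata = az.from_netcdf("trace.nc")  # placeholder path
for warn in warn_divergences(idata) + warn_treedepth(idata):
    log_warning(warn)
```
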
7 changes: 0 additions & 7 deletions pymc/step_methods/compound.py
@@ -59,13 +59,6 @@ def step(self, point):
             point = method.step(point)
         return point

-    def warnings(self):
-        warns = []
-        for method in self.methods:
-            if hasattr(method, "warnings"):
-                warns.extend(method.warnings())
-        return warns
-
     def stop_tuning(self):
         for method in self.methods:
             method.stop_tuning()

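Note: taken together, the removals above replace the old `step.warnings()` / `strace._add_warnings()` plumbing with a single convention: a step method that wants to report something attaches a `SamplerWarning` under the "warning" key of its stats dicts, and `log_warning_stats` picks it up. A hypothetical sketch of what a step method's `step()` might now return (`_astep` and the trigger condition are illustrative, not an actual pymc step method):

```python
from pymc.stats.convergence import SamplerWarning, WarningType

def step(self, point):
    point, stats = self._astep(point)  # hypothetical inner step
    if stats[0].get("diverging"):      # illustrative trigger
        stats[0]["warning"] = SamplerWarning(
            WarningType.DIVERGENCE, "Divergence encountered.", "warn"
        )
    return point, stats
```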