
Collect sampler warnings only through stats #6192


Merged
merged 3 commits into from Oct 28, 2022

Conversation

michaelosthege
Member

@michaelosthege michaelosthege commented Oct 8, 2022

What is this PR about?
Deleting code that gave special treatment to sampler warnings.

This PR mainly changes the following:

  • The SamplerWarning class and related code, including the code for run_convergence_checks, is moved from report.py to convergence.py. Here it is structured more functionally and no longer coupled to the state of a SamplerReport.
  • A new stat called "warning" is introduced. This is an object stat to which samplers can emit warnings instead of piggy-backing them onto the draws.
  • Handling warnings is removed from the BaseTrace, in favor of managing them via the "warning" stat.
  • Updating the minimum ArviZ requirement to the latest version, which can handle object-typed variables (by skipping and warning about them) when saving (related: Support sparse sample stats #6194).
  • A keep_warning_stat flag for pm.sample() was added to enable access to sampler warnings without breaking .to_netcdf() by default.

The logic in sampling.py was modified to stop handling warnings separately.
Instead, the record_and_warn function now takes over the task of logging warnings that come through via the "warning" stat.
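To illustrate the idea, here is a rough sketch (not the actual implementation, which differs in details) of how warnings emitted through the "warning" stat could be picked up and logged, assuming each step method contributes one stats dict per draw:

# Hedged sketch only; the real record_and_warn in sampling.py is more involved.
import logging

_log = logging.getLogger("pymc")

def record_and_warn_sketch(stats_dicts):
    """Log any SamplerWarning found under the "warning" key of the stats dicts."""
    for stats in stats_dicts:
        warning = stats.get("warning")
        if warning is not None:
            _log.warning(warning)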

I also modified the corresponding tests to be more targeted.

Checklist

Major / Breaking Changes

  • Detailed information about divergent samples is no longer captured.
  • Support for handling sampler warnings via ._warnings attributes on samplers and BaseTrace was removed in favor of a "warning" stat.

Bugfixes / New features

  • HMC/NUTS samplers now emit a "warning" stat that ends up in the InferenceData.
  • With pm.sample(keep_warning_stat=True) one can now access detailed info about divergence warnings directly via idata.sample_stats.warning, at the cost of not being able to save these InferenceData objects.
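For illustration, usage could look like this (a minimal sketch with a toy model; not taken from this PR's test suite):

import pymc as pm

with pm.Model():
    pm.Normal("x")
    idata = pm.sample(keep_warning_stat=True)

# The object-typed "warning" entries (e.g. SamplerWarning divergence details) are
# kept here; they may be empty if sampling produced no warnings, and the resulting
# InferenceData can no longer be written with .to_netcdf().
print(idata.sample_stats["warning"])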

Docs / Maintenance

  • New pm.stats.convergence submodule for functional implementation of convergence diagnostics and warnings.

@codecov

codecov bot commented Oct 8, 2022

Codecov Report

Merging #6192 (4096bea) into main (570e6e8) will decrease coverage by 0.03%.
The diff coverage is 84.41%.


@@            Coverage Diff             @@
##             main    #6192      +/-   ##
==========================================
- Coverage   93.77%   93.74%   -0.04%     
==========================================
  Files         101      101              
  Lines       22232    22209      -23     
==========================================
- Hits        20849    20819      -30     
- Misses       1383     1390       +7     
Impacted Files Coverage Δ
pymc/backends/base.py 88.47% <ø> (+0.52%) ⬆️
pymc/step_methods/compound.py 84.61% <ø> (-2.06%) ⬇️
pymc/parallel_sampling.py 88.05% <60.00%> (+2.52%) ⬆️
pymc/step_methods/hmc/base_hmc.py 89.81% <75.00%> (-0.74%) ⬇️
pymc/stats/convergence.py 90.90% <85.71%> (-2.25%) ⬇️
pymc/sampling.py 83.58% <94.44%> (+1.04%) ⬆️
pymc/step_methods/hmc/hmc.py 92.72% <100.00%> (+0.13%) ⬆️
pymc/step_methods/hmc/nuts.py 97.26% <100.00%> (-0.15%) ⬇️
pymc/tests/step_methods/hmc/test_nuts.py 100.00% <100.00%> (ø)
pymc/step_methods/step_sizes.py 75.00% <0.00%> (-25.00%) ⬇️
... and 4 more

@michaelosthege michaelosthege added maintenance trace-backend Traces and ArviZ stuff major Include in major changes release notes section labels Oct 9, 2022
@michaelosthege michaelosthege marked this pull request as ready for review October 9, 2022 11:26
@aseyboldt
Member

Hm, I really don't like losing the extra divergence info. I'm not sure how many people actually use this, but I certainly do :-)
If we want to get rid of the warning system, I think that's fine, but we should then store the information in the sampler stats instead (or at least have an option to do that).

@michaelosthege
Member Author

Hm, I really don't like losing the extra divergence info. I'm not sure how many people actually use this, but I certainly do :-) If we want to get rid of the warning system, I think that's fine, but we should then store the information in the sampler stats instead (or at least have an option to do that).

Unfortunately it's not that easy, because a stat can't be a dictionary.
In addition it is not clear whether stats can be sparse.

Assuming we want to store them for later it would be good to use the same data structure across HMC implementations (PyMC, nutpie, numpyro). Do you have any idea which structure that could be?

@michaelosthege
Member Author

I think we can compromise:
Storing SamplerWarning-typed items in the stats dict and idata.sample_stats group actually works.
...until one tries to save it to_netcdf.

Coordinating with Oriol, I'll open a PR in ArviZ to filter out object-typed variables from the groups before saving. This will make to_netcdf also more robust in case there are other variables with invalid dtype.
(We can try converting them, but if that fails, log a warning.)
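As a rough illustration of that filtering step (a sketch only; the actual ArviZ implementation may differ), one could drop object-dtype variables from an xarray Dataset before saving and warn about what was skipped:

import logging
import xarray as xr

_log = logging.getLogger(__name__)

def drop_object_vars(ds: xr.Dataset) -> xr.Dataset:
    """Return ds without object-dtype variables, warning about what was dropped."""
    bad = [name for name, var in ds.data_vars.items() if var.dtype == object]
    if bad:
        _log.warning("Skipping variables with object dtype when saving: %s", bad)
    return ds.drop_vars(bad)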

@michaelosthege michaelosthege marked this pull request as draft October 10, 2022 21:42
@aseyboldt
Member

I think we should be able to come up with a nice format for those in the sampler stats, without resorting to storing objects.
The most tricky part is dealing with the sparse structure, but we could handle that with a hierarchical index, or alternatively a second variable that maps the divergence to the chain. So something along the lines of this:

# Extra dims in the sample_stats:
"divergence" # One key for each divergence across all chains
"unconstrained_parameter"
# Extra vars:
"divergence_start": ("divergence", "unconstrained_parameter")
"divergence_end": ("divergence", "unconstrained_parameter")
"divergence_chain": ("divergence")  # For each divergence the label of the chain where that divergence happend

Instead of "divergence_chain" we could also use a hierarchical index for the divergence index. I think there is still some trouble when storing hierarchical indices as netcdf, but maybe after an unstack that would be fine?
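A toy construction of that layout (made-up numbers; the variable and dim names are just the ones suggested above) could look like this:

import numpy as np
import xarray as xr

n_div, n_param = 3, 2  # e.g. three divergences across all chains, two unconstrained parameters
sample_stats = xr.Dataset(
    {
        "divergence_start": (("divergence", "unconstrained_parameter"), np.zeros((n_div, n_param))),
        "divergence_end": (("divergence", "unconstrained_parameter"), np.ones((n_div, n_param))),
        "divergence_chain": ("divergence", np.array([0, 0, 2])),
    },
    coords={"unconstrained_parameter": ["x", "y_log__"]},
)
# Plain numeric dtypes, so saving works (requires a netCDF backend such as netcdf4).
sample_stats.to_netcdf("divergences.nc")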

@michaelosthege
Member Author

@aseyboldt a serializable structure would be nice, yes, but I see other issues with higher priority.
With arviz-devs/arviz#2134 ArviZ will automatically drop object variables before saving, which means that we can safely have the SamplerWarning end up in the InferenceData. It just won't be saved, but that wasn't the case before either.

To clean up the sampler stats structure, IMO the next step is #6207 and introducing a coord/naming scheme by which samplers can be identified.

This should also open the door to storing other sampler-wise information such as which variables they were sampling.

In the meantime, this PR will have to wait for the upcoming ArviZ release.

@michaelosthege
Member Author

The last changes I pushed made this PR independent of ArviZ updates.

I updated the description, so this is ready to review/merge :)

@michaelosthege
Member Author

@pymc-devs/core-contributors can I get reviews here?
I'd like to move forward before git conflicts creep in..

@OriolAbril
Member

I think we should be able to come up with a nice format for those in the sampler stats, without resorting to storing objects.
The most tricky part is dealing with the sparse structure, but we could handle that with a hierarchical index, or alternatively a second variable that maps the divergence to the chain. So something along the lines of this:

# Extra dims in the sample_stats:
"divergence" # One key for each divergence across all chains
"unconstrained_parameter"
# Extra vars:
"divergence_start": ("divergence", "unconstrained_parameter")
"divergence_end": ("divergence", "unconstrained_parameter")
"divergence_chain": ("divergence")  # For each divergence the label of the chain where that divergence happend

I think it is worth moving that to an issue. I tried going over the data in the divergences warning but was unable to understand exactly what info was there and how to store it in "regular" arrays. Defining a structure that works in all cases might require a bit of extra work to account for multiple samplers, or adding some exceptions for this to work if NUTS is the only sampler, but it might be somewhat feasible. At least on the xarray side, dimensions are shared, but variables in a dataset can have any dimensions (there is no need for all variables to have chain and draw dims, for example) and a dataset can have as many dimensions and coordinates as desired.
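For instance (a toy xarray sketch, not PyMC output), a sample_stats-like dataset can mix per-draw variables with per-divergence variables:

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        # regular per-draw stat with (chain, draw) dims
        "energy": (("chain", "draw"), np.zeros((2, 100))),
        # sparse divergence info with its own dim; no chain/draw dims needed
        "divergence_chain": ("divergence", np.array([0, 1, 1])),
    }
)
print(ds)  # chain, draw and divergence dims coexist in one dataset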

Member

@lucianopaz lucianopaz left a comment


Looks good to me, @michaelosthege . I agree with @OriolAbril that we need to open an issue to change the warning data structure to something that is serializable. It’s a real waste to lose all the warning information when storing the results to netcdf.

Default to `pm.sample(keep_warning_stat=False)`
to maintain compatibility with saving InferenceData.

Closes #6191
@michaelosthege
Member Author

Thanks Luciano!

As you can see, I created #6252.

I'll merge this before another git conflict creeps in..

Labels
enhancements maintenance major Include in major changes release notes section trace-backend Traces and ArviZ stuff