-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Collect sampler warnings only through stats #6192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6192 +/- ##
==========================================
- Coverage 93.77% 93.74% -0.04%
==========================================
Files 101 101
Lines 22232 22209 -23
==========================================
- Hits 20849 20819 -30
- Misses 1383 1390 +7
|
884b85f
to
b013b52
Compare
Hm, I really don't like loosing the extra divergence info. I'm not sure how many people actually use this, but I certainly do :-) |
Unfortunately it's not that easy, because a stat can't be a dictionary. Assuming we want to store them for later it would be good to use the same data structure across HMC implementations (PyMC, nutpie, numpyro). Do you have any idea which structure that could be? |
b013b52
to
c065c46
Compare
I think we can compromise: Coordinating with Oriol, I'll open a PR in ArviZ to filter out |
I think we should be able to come up with a nice format for those in the sampler stats, without resorting to storing objects.
Instead of "divergence_chain" we could also use a hierarchical index for the divergence index. I think there is still some trouble when storing hierarchical indices as netcdf, but maybe after an unstack that would be fine? |
c065c46
to
2e12554
Compare
2e12554
to
8239daa
Compare
@aseyboldt a serializable structure would be nice, yes, but I see other issues with higher priority.. To clean up the sampler stats structure, IMO the next step is #6207 and introducing a coord/naming scheme by which samplers can be identified. This should also open the door to storing other sampler-wise information such as which variables they were sampling. In the meantime, this PR will have to wait for the upcoming ArviZ release. |
35cc6da
to
a32db66
Compare
The last changes I pushed made thus PR independent of ArviZ updates. I updated the description, so this is ready to review/merge :) |
@pymc-devs/core-contributors can I get reviews here? |
I think it is worth it to move that to an issue. I tried going over the data in the divergences warning but was unable to understand what info was exactly there and how to store it in "regular" arrays. Defining a structure that works in all cases might require a bit of extra work to account for multiple samplers or adding some exceptions for this to work if nuts is the only sampler, but it might be somewhat feasible. At least on xarray side, the dimensions are shared, but variables in a dataset can have any dimensions (there is not need for all variables to have chain and draw dims for example) and a dataset can have as many dimensions and coordinates as desired. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, @michaelosthege . I agree with @OriolAbril that we need to open an issue to change the warning data structure to something that is serializable. It’s a real waste to lose all the warning information when storing the results to netcdf.
Default to `pm.sample(keep_warning_stat=False)` to maintain compatibility with saving InferenceData. Closes #6191
988b21a
to
8ef5990
Compare
Thanks Luciano! As you can see, I created #6252. I'll merge this before another git conflict creeps in.. |
What is this PR about?
Deleting code that gave special treatments to sampler warnings.
This PR mostly changes three things:
SamplerWarning
class and related code, including the code forrun_convergence_checks
is moved fromreport.py
toconvergence.py
. Here it is more functionally structure, and no longer coupled to the state of aSamplerReport
."warning"
is introduced. This is anobject
stat to which samplers can emit warnings instead of piggy-backing them onto the draws.BaseTrace
, in favor of managing them via the "warning" stat.updating the minimum ArviZ requirement to the latest version that can handleobject
-typed variables (by skipping and warning about them) when saving (related Support sparse sample stats #6194).pm.sample(keep_warning_stat={False, True})
setting was added to enable access to sampler warnings without breaking.to_netcdf()
by default.The logic in
sampling.py
was modified to stop handling warnings separately.Instead, the
record_and_warn
function now takes over the task of logging warnings that are coming through via the "warning" stat.I also modified the corresponding tests to be more targeted.
Checklist
Major / Breaking Changes
._warnings
attributes on samplers andBaseTrace
was removed in favor of a "warning" stat.Bugfixes / New features
InferenceData
.pm.sample(keep_warning_stat=True)
one can now access detailed info about divergence warnings directly viaidata.sample_stats.warning
, at the cost of not being able to save theseInferenceData
objects.Docs / Maintenance
pm.stats.convergence
submodule for functional implementation of convergence diagnostics and warnings.