-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add var_names argument to sample #7206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Just getting started here. Testing to see if this is the right approach. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @fonnesbeck
Left a small comment. Also I think the jax samplers kind of allow this functionality, so would only need to forward to the sample_external_nuts
code path?
Apologies, I didn't see it was in draft :) |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7206 +/- ##
==========================================
- Coverage 92.29% 90.29% -2.01%
==========================================
Files 100 100
Lines 16875 16896 +21
==========================================
- Hits 15575 15256 -319
- Misses 1300 1640 +340
|
bed2f0a
to
3a1cdad
Compare
No worries, feedback welcome at all stages! (earlier the better, in fact) |
Should probably enforce that all stochastic variables be included. |
Maybe it's fine not to? |
The numpyro sampler does not appear to do the right thing with the passed var names. |
What does it do? Checking the source code, it looks like it should if you pass just the strings? |
Better docstring Co-authored-by: Ricardo Vieira <[email protected]>
Seems like at some point the names don't get converted vars. |
(ok, fixed) |
Not sure how code coverage drops when I've added two tests. |
Cov is flaky, not always up to date or comparing with the right commit |
@@ -348,6 +349,7 @@ def _sample_external_nuts( | |||
random_seed=random_seed, | |||
initvals=initvals, | |||
model=model, | |||
var_names=var_names, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a warning about var_names not beeing used by nutpie like we have for some other arguments above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, @aseyboldt how hard/reasonable is it to support this in nutpie?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps it could be filtered in nutpie.sample._trace_to_arviz
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally we want to filter during sampling already since RAM is usually the issue, not disk-space?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but there are no obvious hooks into the nutpie compiled model. It would require some changes on the nutpie side, by the looks of it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My top comment was to add a warning like these:
Lines 283 to 294 in abe7bc9
if initvals is not None: | |
warnings.warn( | |
"`initvals` are currently not passed to nutpie sampler. " | |
"Use `init_mean` kwarg following nutpie specification instead.", | |
UserWarning, | |
) | |
if idata_kwargs is not None: | |
warnings.warn( | |
"`idata_kwargs` are currently ignored by the nutpie sampler", | |
UserWarning, | |
) |
Not to try to monkey-patch nutpie from the outside
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once/if nutpie has similar functionality we can forward it from pymc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's go with the warning for now, and create an issue on nutpie for a solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't be too hard. Nutpie uses a numba function to compute all values that should appear in the trace (including the deterministics and transformed values). We should be able to just export a subset (code is around here: https://github.com/pymc-devs/nutpie/blob/main/python/nutpie/compile_pymc.py#L387)
Thanks @fonnesbeck ! |
Description
Allow for filtering of variables included in sampled trace via an optional
var_names
argument, similar to what is done for plotting.Related Issue
pm.sample
via avar_names
kwarg #7068Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7206.org.readthedocs.build/en/7206/