-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add var_names argument to sample #7206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
3a1cdad
911263d
2733886
2092ea1
55c06cb
16585a9
c9f3224
efbeef9
13a5d31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -264,6 +264,7 @@ def _sample_external_nuts( | |||||||||||||||||||||||||
random_seed: Union[RandomState, None], | ||||||||||||||||||||||||||
initvals: Union[StartDict, Sequence[Optional[StartDict]], None], | ||||||||||||||||||||||||||
model: Model, | ||||||||||||||||||||||||||
var_names: Optional[Sequence[str]], | ||||||||||||||||||||||||||
progressbar: bool, | ||||||||||||||||||||||||||
idata_kwargs: Optional[dict], | ||||||||||||||||||||||||||
nuts_sampler_kwargs: Optional[dict], | ||||||||||||||||||||||||||
|
@@ -348,6 +349,7 @@ def _sample_external_nuts( | |||||||||||||||||||||||||
random_seed=random_seed, | ||||||||||||||||||||||||||
initvals=initvals, | ||||||||||||||||||||||||||
model=model, | ||||||||||||||||||||||||||
var_names=var_names, | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a warning about var_names not beeing used by nutpie like we have for some other arguments above? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, @aseyboldt how hard/reasonable is it to support this in nutpie? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps it could be filtered in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally we want to filter during sampling already since RAM is usually the issue, not disk-space? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, but there are no obvious hooks into the nutpie compiled model. It would require some changes on the nutpie side, by the looks of it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My top comment was to add a warning like these: Lines 283 to 294 in abe7bc9
Not to try to monkey-patch nutpie from the outside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once/if nutpie has similar functionality we can forward it from pymc? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's go with the warning for now, and create an issue on nutpie for a solution. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't be too hard. Nutpie uses a numba function to compute all values that should appear in the trace (including the deterministics and transformed values). We should be able to just export a subset (code is around here: https://github.com/pymc-devs/nutpie/blob/main/python/nutpie/compile_pymc.py#L387) |
||||||||||||||||||||||||||
progressbar=progressbar, | ||||||||||||||||||||||||||
nuts_sampler=sampler, | ||||||||||||||||||||||||||
idata_kwargs=idata_kwargs, | ||||||||||||||||||||||||||
|
@@ -371,6 +373,7 @@ def sample( | |||||||||||||||||||||||||
random_seed: RandomState = None, | ||||||||||||||||||||||||||
progressbar: bool = True, | ||||||||||||||||||||||||||
step=None, | ||||||||||||||||||||||||||
var_names: Optional[Sequence[str]] = None, | ||||||||||||||||||||||||||
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | ||||||||||||||||||||||||||
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, | ||||||||||||||||||||||||||
init: str = "auto", | ||||||||||||||||||||||||||
|
@@ -399,6 +402,7 @@ def sample( | |||||||||||||||||||||||||
random_seed: RandomState = None, | ||||||||||||||||||||||||||
progressbar: bool = True, | ||||||||||||||||||||||||||
step=None, | ||||||||||||||||||||||||||
var_names: Optional[Sequence[str]] = None, | ||||||||||||||||||||||||||
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | ||||||||||||||||||||||||||
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, | ||||||||||||||||||||||||||
init: str = "auto", | ||||||||||||||||||||||||||
|
@@ -427,6 +431,7 @@ def sample( | |||||||||||||||||||||||||
random_seed: RandomState = None, | ||||||||||||||||||||||||||
progressbar: bool = True, | ||||||||||||||||||||||||||
step=None, | ||||||||||||||||||||||||||
var_names: Optional[Sequence[str]] = None, | ||||||||||||||||||||||||||
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | ||||||||||||||||||||||||||
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, | ||||||||||||||||||||||||||
init: str = "auto", | ||||||||||||||||||||||||||
|
@@ -478,6 +483,8 @@ def sample( | |||||||||||||||||||||||||
A step function or collection of functions. If there are variables without step methods, | ||||||||||||||||||||||||||
step methods for those variables will be assigned automatically. By default the NUTS step | ||||||||||||||||||||||||||
method will be used, if appropriate to the model. | ||||||||||||||||||||||||||
var_names : list of str, optional | ||||||||||||||||||||||||||
Names of variables to be stored in the trace. Defaults to all free variables and deterministics. | ||||||||||||||||||||||||||
nuts_sampler : str | ||||||||||||||||||||||||||
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. | ||||||||||||||||||||||||||
This requires the chosen sampler to be installed. | ||||||||||||||||||||||||||
|
@@ -680,6 +687,7 @@ def sample( | |||||||||||||||||||||||||
random_seed=random_seed, | ||||||||||||||||||||||||||
initvals=initvals, | ||||||||||||||||||||||||||
model=model, | ||||||||||||||||||||||||||
var_names=var_names, | ||||||||||||||||||||||||||
progressbar=progressbar, | ||||||||||||||||||||||||||
idata_kwargs=idata_kwargs, | ||||||||||||||||||||||||||
nuts_sampler_kwargs=nuts_sampler_kwargs, | ||||||||||||||||||||||||||
|
@@ -722,12 +730,19 @@ def sample( | |||||||||||||||||||||||||
model.check_start_vals(ip) | ||||||||||||||||||||||||||
_check_start_shape(model, ip) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if var_names is not None: | ||||||||||||||||||||||||||
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names] | ||||||||||||||||||||||||||
fonnesbeck marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
assert len(trace_vars) == len(var_names), "Not all var_names were found in the model" | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
trace_vars = None | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Create trace backends for each chain | ||||||||||||||||||||||||||
run, traces = init_traces( | ||||||||||||||||||||||||||
backend=trace, | ||||||||||||||||||||||||||
chains=chains, | ||||||||||||||||||||||||||
expected_length=draws + tune, | ||||||||||||||||||||||||||
step=step, | ||||||||||||||||||||||||||
trace_vars=trace_vars, | ||||||||||||||||||||||||||
initial_point=ip, | ||||||||||||||||||||||||||
model=model, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
@@ -739,6 +754,7 @@ def sample( | |||||||||||||||||||||||||
"traces": traces, | ||||||||||||||||||||||||||
"chains": chains, | ||||||||||||||||||||||||||
"tune": tune, | ||||||||||||||||||||||||||
"var_names": var_names, | ||||||||||||||||||||||||||
"progressbar": progressbar, | ||||||||||||||||||||||||||
"model": model, | ||||||||||||||||||||||||||
"cores": cores, | ||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.