Skip to content

Commit 4c3ec36

Browse files
Add var_names argument to sample (#7206)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent cdfa8c0 commit 4c3ec36

File tree

5 files changed

+51
-4
lines changed

5 files changed

+51
-4
lines changed

pymc/backends/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767

6868
import numpy as np
6969

70+
from pytensor.tensor.variable import TensorVariable
7071
from typing_extensions import TypeAlias
7172

7273
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
@@ -99,11 +100,12 @@ def _init_trace(
99100
stats_dtypes: list[dict[str, type]],
100101
trace: Optional[BaseTrace],
101102
model: Model,
103+
trace_vars: Optional[list[TensorVariable]] = None,
102104
) -> BaseTrace:
103105
"""Initializes a trace backend for a chain."""
104106
strace: BaseTrace
105107
if trace is None:
106-
strace = NDArray(model=model)
108+
strace = NDArray(model=model, vars=trace_vars)
107109
elif isinstance(trace, BaseTrace):
108110
if len(trace) > 0:
109111
raise ValueError("Continuation of traces is no longer supported.")
@@ -123,6 +125,7 @@ def init_traces(
123125
step: Union[BlockedStep, CompoundStep],
124126
initial_point: Mapping[str, np.ndarray],
125127
model: Model,
128+
trace_vars: Optional[list[TensorVariable]] = None,
126129
) -> tuple[Optional[RunType], Sequence[IBaseTrace]]:
127130
"""Initializes a trace recorder for each chain."""
128131
if HAS_MCB and isinstance(backend, Backend):
@@ -142,6 +145,7 @@ def init_traces(
142145
chain_number=chain_number,
143146
trace=backend,
144147
model=model,
148+
trace_vars=trace_vars,
145149
)
146150
for chain_number in range(chains)
147151
]

pymc/sampling/jax.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,15 +532,19 @@ def sample_jax_nuts(
532532

533533
model = modelcontext(model)
534534

535-
if var_names is None:
536-
var_names = model.unobserved_value_vars
535+
if var_names is not None:
536+
filtered_var_names = [v for v in model.unobserved_value_vars if v.name in var_names]
537+
else:
538+
filtered_var_names = model.unobserved_value_vars
537539

538540
if nuts_kwargs is None:
539541
nuts_kwargs = {}
540542
else:
541543
nuts_kwargs = nuts_kwargs.copy()
542544

543-
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
545+
vars_to_sample = list(
546+
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
547+
)
544548

545549
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
546550

pymc/sampling/mcmc.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def _sample_external_nuts(
264264
random_seed: Union[RandomState, None],
265265
initvals: Union[StartDict, Sequence[Optional[StartDict]], None],
266266
model: Model,
267+
var_names: Optional[Sequence[str]],
267268
progressbar: bool,
268269
idata_kwargs: Optional[dict],
269270
nuts_sampler_kwargs: Optional[dict],
@@ -292,6 +293,11 @@ def _sample_external_nuts(
292293
"`idata_kwargs` are currently ignored by the nutpie sampler",
293294
UserWarning,
294295
)
296+
if var_names is not None:
297+
warnings.warn(
298+
"`var_names` are currently ignored by the nutpie sampler",
299+
UserWarning,
300+
)
295301
compiled_model = nutpie.compile_pymc_model(model)
296302
t_start = time.time()
297303
idata = nutpie.sample(
@@ -348,6 +354,7 @@ def _sample_external_nuts(
348354
random_seed=random_seed,
349355
initvals=initvals,
350356
model=model,
357+
var_names=var_names,
351358
progressbar=progressbar,
352359
nuts_sampler=sampler,
353360
idata_kwargs=idata_kwargs,
@@ -371,6 +378,7 @@ def sample(
371378
random_seed: RandomState = None,
372379
progressbar: bool = True,
373380
step=None,
381+
var_names: Optional[Sequence[str]] = None,
374382
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
375383
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
376384
init: str = "auto",
@@ -399,6 +407,7 @@ def sample(
399407
random_seed: RandomState = None,
400408
progressbar: bool = True,
401409
step=None,
410+
var_names: Optional[Sequence[str]] = None,
402411
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
403412
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
404413
init: str = "auto",
@@ -427,6 +436,7 @@ def sample(
427436
random_seed: RandomState = None,
428437
progressbar: bool = True,
429438
step=None,
439+
var_names: Optional[Sequence[str]] = None,
430440
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
431441
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
432442
init: str = "auto",
@@ -478,6 +488,8 @@ def sample(
478488
A step function or collection of functions. If there are variables without step methods,
479489
step methods for those variables will be assigned automatically. By default the NUTS step
480490
method will be used, if appropriate to the model.
491+
var_names : list of str, optional
492+
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
481493
nuts_sampler : str
482494
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
483495
This requires the chosen sampler to be installed.
@@ -680,6 +692,7 @@ def sample(
680692
random_seed=random_seed,
681693
initvals=initvals,
682694
model=model,
695+
var_names=var_names,
683696
progressbar=progressbar,
684697
idata_kwargs=idata_kwargs,
685698
nuts_sampler_kwargs=nuts_sampler_kwargs,
@@ -722,12 +735,19 @@ def sample(
722735
model.check_start_vals(ip)
723736
_check_start_shape(model, ip)
724737

738+
if var_names is not None:
739+
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]
740+
assert len(trace_vars) == len(var_names), "Not all var_names were found in the model"
741+
else:
742+
trace_vars = None
743+
725744
# Create trace backends for each chain
726745
run, traces = init_traces(
727746
backend=trace,
728747
chains=chains,
729748
expected_length=draws + tune,
730749
step=step,
750+
trace_vars=trace_vars,
731751
initial_point=ip,
732752
model=model,
733753
)
@@ -739,6 +759,7 @@ def sample(
739759
"traces": traces,
740760
"chains": chains,
741761
"tune": tune,
762+
"var_names": var_names,
742763
"progressbar": progressbar,
743764
"model": model,
744765
"cores": cores,

tests/sampling/test_jax.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,15 @@ def test_sample_partially_observed():
491491
assert idata.posterior["x"].shape == (1, 10, 3)
492492

493493

494+
def test_sample_var_names():
495+
with pm.Model() as model:
496+
a = pm.Normal("a")
497+
b = pm.Deterministic("b", a**2)
498+
idata = pm.sample(10, tune=10, nuts_sampler="numpyro", var_names=["a"])
499+
assert "a" in idata.posterior
500+
assert "b" not in idata.posterior
501+
502+
494503
@pytest.mark.parametrize("nuts_sampler", ("numpyro", "blackjax"))
495504
def test_convergence_warnings(caplog, nuts_sampler):
496505
with pm.Model() as m:

tests/sampling/test_mcmc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,15 @@ def test_no_init_nuts_compound(caplog):
694694
assert "Initializing NUTS" not in caplog.text
695695

696696

697+
def test_sample_var_names():
698+
with pm.Model() as model:
699+
a = pm.Normal("a")
700+
b = pm.Deterministic("b", a**2)
701+
idata = pm.sample(10, tune=10, var_names=["a"])
702+
assert "a" in idata.posterior
703+
assert "b" not in idata.posterior
704+
705+
697706
class TestAssignStepMethods:
698707
def test_bernoulli(self):
699708
"""Test bernoulli distribution is assigned binary gibbs metropolis method"""

0 commit comments

Comments
 (0)