Skip to content

Commit 3a1cdad

Browse files
committed
Draft of var_names arg for sample
1 parent aa679f3 commit 3a1cdad

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

pymc/backends/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767

6868
import numpy as np
6969

70+
from pytensor.tensor.variable import TensorVariable
7071
from typing_extensions import TypeAlias
7172

7273
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
@@ -99,11 +100,12 @@ def _init_trace(
99100
stats_dtypes: list[dict[str, type]],
100101
trace: Optional[BaseTrace],
101102
model: Model,
103+
trace_vars: Optional[list[TensorVariable]] = None,
102104
) -> BaseTrace:
103105
"""Initializes a trace backend for a chain."""
104106
strace: BaseTrace
105107
if trace is None:
106-
strace = NDArray(model=model)
108+
strace = NDArray(model=model, vars=trace_vars)
107109
elif isinstance(trace, BaseTrace):
108110
if len(trace) > 0:
109111
raise ValueError("Continuation of traces is no longer supported.")
@@ -123,6 +125,7 @@ def init_traces(
123125
step: Union[BlockedStep, CompoundStep],
124126
initial_point: Mapping[str, np.ndarray],
125127
model: Model,
128+
trace_vars: Optional[list[TensorVariable]] = None,
126129
) -> tuple[Optional[RunType], Sequence[IBaseTrace]]:
127130
"""Initializes a trace recorder for each chain."""
128131
if HAS_MCB and isinstance(backend, Backend):
@@ -142,6 +145,7 @@ def init_traces(
142145
chain_number=chain_number,
143146
trace=backend,
144147
model=model,
148+
trace_vars=trace_vars,
145149
)
146150
for chain_number in range(chains)
147151
]

pymc/sampling/mcmc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ def sample(
371371
random_seed: RandomState = None,
372372
progressbar: bool = True,
373373
step=None,
374+
var_names: Optional[Sequence[str]] = None,
374375
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
375376
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
376377
init: str = "auto",
@@ -399,6 +400,7 @@ def sample(
399400
random_seed: RandomState = None,
400401
progressbar: bool = True,
401402
step=None,
403+
var_names: Optional[Sequence[str]] = None,
402404
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
403405
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
404406
init: str = "auto",
@@ -427,6 +429,7 @@ def sample(
427429
random_seed: RandomState = None,
428430
progressbar: bool = True,
429431
step=None,
432+
var_names: Optional[Sequence[str]] = None,
430433
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
431434
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
432435
init: str = "auto",
@@ -478,6 +481,8 @@ def sample(
478481
A step function or collection of functions. If there are variables without step methods,
479482
step methods for those variables will be assigned automatically. By default the NUTS step
480483
method will be used, if appropriate to the model.
484+
var_names : list of str
485+
Names of variables to be monitored. If None, all named variables are selected automatically.
481486
nuts_sampler : str
482487
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
483488
This requires the chosen sampler to be installed.
@@ -722,12 +727,18 @@ def sample(
722727
model.check_start_vals(ip)
723728
_check_start_shape(model, ip)
724729

730+
if var_names is not None:
731+
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]
732+
else:
733+
trace_vars = None
734+
725735
# Create trace backends for each chain
726736
run, traces = init_traces(
727737
backend=trace,
728738
chains=chains,
729739
expected_length=draws + tune,
730740
step=step,
741+
trace_vars=trace_vars,
731742
initial_point=ip,
732743
model=model,
733744
)
@@ -739,6 +750,7 @@ def sample(
739750
"traces": traces,
740751
"chains": chains,
741752
"tune": tune,
753+
"var_names": var_names,
742754
"progressbar": progressbar,
743755
"model": model,
744756
"cores": cores,

0 commit comments

Comments
 (0)