Skip to content

Commit 2d29175

Browse files
committed
Add utility to compute deterministics
1 parent 7ad71f8 commit 2d29175

File tree

5 files changed

+196
-0
lines changed

5 files changed

+196
-0
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ jobs:
6969
- |
7070
tests/distributions/test_censored.py
7171
tests/distributions/test_simulator.py
72+
tests/sampling/test_deterministic.py
7273
tests/sampling/test_forward.py
7374
tests/sampling/test_population.py
7475
tests/stats/test_convergence.py

docs/source/api/samplers.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ This submodule contains functions for MCMC and forward sampling.
1313
sample_posterior_predictive
1414
draw
1515

16+
.. currentmodule:: pymc.sampling.deterministic
17+
compute_deterministics
18+
1619

1720
.. currentmodule:: pymc.sampling.mcmc
1821

pymc/sampling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from pymc.sampling.deterministic import compute_deterministics
1516
from pymc.sampling.forward import *
1617
from pymc.sampling.mcmc import *

pymc/sampling/deterministic.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Sequence
15+
16+
import xarray
17+
18+
from xarray import Dataset
19+
20+
from pymc.backends.arviz import apply_function_over_dataset, coords_and_dims_for_inferencedata
21+
from pymc.model.core import Model, modelcontext
22+
23+
24+
def compute_deterministics(
    dataset: Dataset,
    *,
    var_names: Sequence[str] | None = None,
    model: Model | None = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    merge_dataset: bool = False,
    progressbar: bool = True,
    compile_kwargs: dict | None = None,
) -> Dataset:
    """Compute model deterministics given a dataset with values for model variables.

    Parameters
    ----------
    dataset : Dataset
        Dataset with values for model variables. Commonly InferenceData["posterior"].
        It is indexed with the names of all of the model's free random variables.
    var_names : sequence of str, optional
        Names of the deterministic variables to compute.
        If None, compute all deterministics in the model.
    model : Model, optional
        Model to use. If None, use context model.
    sample_dims : sequence of str, default ("chain", "draw")
        Sample (batch) dimensions of the dataset over which to compute the deterministics.
    merge_dataset : bool, default False
        Whether to extend the original dataset or return a new one.
    progressbar : bool, default True
        Whether to display a progress bar in the command line.
    compile_kwargs : dict, optional
        Additional arguments passed to `model.compile_fn`.

    Returns
    -------
    Dataset
        Dataset with values for the requested deterministics. If `merge_dataset`
        is True, the result of merging `dataset` with those values.

    Raises
    ------
    ValueError
        If any entry of `var_names` names a model variable that is not one of
        the model's deterministics.

    Examples
    --------
    .. code:: python

        import pymc as pm

        with pm.Model(coords={"group": (0, 2, 4)}) as m:
            mu_raw = pm.Normal("mu_raw", 0, 1, dims="group")
            mu = pm.Deterministic("mu", mu_raw.cumsum(), dims="group")

            trace = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5)

        assert "mu" not in trace.posterior

        with m:
            trace.posterior = pm.compute_deterministics(trace.posterior, merge_dataset=True)

        assert "mu" in trace.posterior

    """
    model = modelcontext(model)

    if var_names is None:
        deterministics = model.deterministics
    else:
        deterministics = [model[var_name] for var_name in var_names]
        if not set(deterministics).issubset(set(model.deterministics)):
            raise ValueError("Not all var_names corresponded to model deterministics")

    fn = model.compile_fn(
        inputs=model.free_RVs,
        outs=deterministics,
        # Free RVs that none of the requested deterministics depend on are still
        # passed as inputs, so unused inputs must not raise.
        on_unused_input="ignore",
        **(compile_kwargs or {}),
    )

    coords, dims = coords_and_dims_for_inferencedata(model)

    new_dataset = apply_function_over_dataset(
        fn,
        dataset[[rv.name for rv in model.free_RVs]],
        # Label the outputs with the deterministics that were actually compiled
        # into `fn`, in the same order. Using `model.deterministics` here would
        # mislabel the results whenever `var_names` selects a subset.
        output_var_names=[det.name for det in deterministics],
        dims=dims,
        coords=coords,
        sample_dims=sample_dims,
        progressbar=progressbar,
    )

    if merge_dataset:
        # NOTE(review): per the xarray docs, compat="override" skips comparison
        # and keeps the variable from the first object, so a deterministic that
        # already exists in `dataset` would not be replaced - confirm intended.
        new_dataset = xarray.merge([dataset, new_dataset], compat="override")

    return new_dataset

tests/sampling/test_deterministic.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
17+
from numpy.testing import assert_allclose
18+
19+
from pymc.distributions import Normal
20+
from pymc.model.core import Deterministic, Model
21+
from pymc.sampling.deterministic import compute_deterministics
22+
from pymc.sampling.forward import sample_prior_predictive
23+
24+
# Turn all warnings into errors for this module
25+
pytestmark = pytest.mark.filterwarnings("error")
26+
27+
28+
def test_compute_deterministics():
    """Check `compute_deterministics` with default and with custom arguments."""
    with Model(coords={"group": (0, 2, 4)}) as model:
        mu_raw = Normal("mu_raw", 0, 1, dims="group")
        Deterministic("mu", mu_raw.cumsum(), dims="group")

        sigma_raw = Normal("sigma_raw", 0, 1)
        Deterministic("sigma", sigma_raw.exp())

    # Draw only the free variables; the deterministics are recomputed below.
    prior = sample_prior_predictive(
        samples=5,
        model=model,
        var_names=["mu_raw", "sigma_raw"],
        random_seed=22,
    ).prior

    expected_mu = prior["mu_raw"].cumsum("group")
    expected_sigma = np.exp(prior["sigma_raw"])

    # Default arguments: model taken from the context, all deterministics computed.
    with model:
        computed = compute_deterministics(prior)
    assert set(computed.data_vars.variables) == {"mu", "sigma"}
    assert computed["mu"].dims == ("chain", "draw", "group")
    assert computed["sigma"].dims == ("chain", "draw")
    assert_allclose(computed["mu"], expected_mu)
    assert_allclose(computed["sigma"], expected_sigma)

    # Custom arguments: explicit model, a single deterministic, merged output.
    custom_kwargs = {
        "var_names": ["mu"],
        "merge_dataset": True,
        "model": model,
        "compile_kwargs": {"mode": "FAST_COMPILE"},
        "progressbar": False,
    }
    merged = compute_deterministics(prior, **custom_kwargs)
    assert set(merged.data_vars.variables) == {"mu_raw", "sigma_raw", "mu"}
    assert merged["mu"].dims == ("chain", "draw", "group")
    assert_allclose(merged["mu"], expected_mu)
61+
62+
63+
def test_docstring_example():
    """Run the same steps as the usage example of `compute_deterministics`."""
    import pymc as pm

    with pm.Model(coords={"group": (0, 2, 4)}) as model:
        mu_raw = pm.Normal("mu_raw", 0, 1, dims="group")
        pm.Deterministic("mu", mu_raw.cumsum(), dims="group")

        # Sampling is restricted to "mu_raw" via var_names.
        idata = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5)

    assert "mu" not in idata.posterior

    # Recompute the deterministic and merge it into the posterior group.
    with model:
        idata.posterior = pm.compute_deterministics(idata.posterior, merge_dataset=True)

    assert "mu" in idata.posterior

0 commit comments

Comments
 (0)