Skip to content

Commit 63cf0e8

Browse files
committed
Make do interventions shared variables by default
1 parent 62335ac commit 63cf0e8

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

pymc/model/transform/conditioning.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from collections.abc import Mapping, Sequence
1717
from typing import Any, Union
1818

19-
from pytensor.graph import ancestors
19+
import pytensor
20+
21+
from pytensor.graph import Constant, ancestors
2022
from pytensor.tensor import TensorVariable
2123

2224
from pymc.logprob.transforms import Transform
@@ -126,7 +128,9 @@ def observe(
126128
def do(
127129
model: Model,
128130
vars_to_interventions: Mapping[Union["str", TensorVariable], Any],
129-
prune_vars=False,
131+
*,
132+
make_interventions_shared: bool = True,
133+
prune_vars: bool = False,
130134
) -> Model:
131135
"""Replace model variables by intervention variables.
132136
@@ -140,6 +144,8 @@ def do(
140144
Dictionary that maps model variables (or names) to intervention expressions.
141145
Intervention expressions must have a shape and data type that is compatible
142146
with the original model variable.
147+
make_interventions_shared: bool, defaults to True,
148+
Whether to make constant interventions shared variables.
143149
prune_vars: bool, defaults to False
144150
Whether to prune model variables that are not connected to any observed variables,
145151
after the interventions.
@@ -170,11 +176,14 @@ def do(
170176
171177
"""
172178
do_mapping = {}
173-
for var, obs in vars_to_interventions.items():
179+
for var, intervention in vars_to_interventions.items():
174180
if isinstance(var, str):
175181
var = model[var]
176182
try:
177-
do_mapping[var] = var.type.filter_variable(obs)
183+
intervention = var.type.filter_variable(intervention)
184+
if make_interventions_shared and isinstance(intervention, Constant):
185+
intervention = pytensor.shared(intervention.data, name=var.name)
186+
do_mapping[var] = intervention
178187
except TypeError as err:
179188
raise TypeError(
180189
"Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables"

tests/model/transform/test_conditioning.py

+43
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
import pytest
1717

1818
from pytensor import config
19+
from pytensor.compile import SharedVariable
20+
from pytensor.graph import Constant
1921

2022
import pymc as pm
2123

24+
from pymc import sample_posterior_predictive, set_data
2225
from pymc.distributions.transforms import logodds
2326
from pymc.model.transform.conditioning import (
2427
change_value_transforms,
@@ -253,6 +256,46 @@ def test_do_self_reference():
253256
np.testing.assert_allclose(draw_x + 100, draw_do_x)
254257

255258

259+
def test_do_make_intervenstions_shared():
260+
with pm.Model(coords={"obs": [1]}) as m:
261+
x = pm.Normal("x", dims="obs")
262+
y = pm.Normal("y", dims="obs")
263+
264+
constant_m = do(m, {x: [0.5]}, make_interventions_shared=False)
265+
constant_x = constant_m["x"]
266+
assert isinstance(constant_x, Constant)
267+
np.testing.assert_array_equal(constant_x.data, [0.5])
268+
269+
shared_m = do(m, {x: [0.5]}, make_interventions_shared=True)
270+
shared_x = shared_m["x"]
271+
assert isinstance(shared_x, SharedVariable)
272+
np.testing.assert_array_equal(shared_x.get_value(borrow=True), [0.5])
273+
274+
with shared_m:
275+
set_data({"x": [0.6, 0.9]}, coords={"obs": [2, 3]})
276+
pp_y = pm.sample_prior_predictive(draws=3).prior["y"]
277+
assert pp_y.sizes == {"chain": 1, "draw": 3, "obs": 2}
278+
assert pp_y.shape == (1, 3, 2)
279+
280+
281+
@pytest.mark.parametrize(
282+
"make_interventions_shared",
283+
[True, pytest.param(False, marks=pytest.mark.xfail(reason="#6876"))],
284+
)
285+
def test_do_sample_posterior_predictive(make_interventions_shared):
286+
# Regression test for https://github.com/pymc-devs/pymc/issues/6977
287+
with pm.Model() as model:
288+
a = pm.Normal("a")
289+
b = pm.Deterministic("b", a * 2)
290+
c = pm.Normal("c", b / 2)
291+
292+
idata = az.from_dict({"a": [[1.0]], "b": [[2.0]], "c": [[1.0]]})
293+
294+
with do(model, {a: 1000}, make_interventions_shared=make_interventions_shared):
295+
pp = sample_posterior_predictive(idata, var_names=["c"], predictions=True).predictions
296+
assert (pp["c"] > 500).all()
297+
298+
256299
def test_change_value_transforms():
257300
with pm.Model() as base_m:
258301
p = pm.Uniform("p", 0, 1, default_transform=None)

0 commit comments

Comments
 (0)