|
16 | 16 | import pytest
|
17 | 17 |
|
18 | 18 | from pytensor import config
|
| 19 | +from pytensor.compile import SharedVariable |
| 20 | +from pytensor.graph import Constant |
19 | 21 |
|
20 | 22 | import pymc as pm
|
21 | 23 |
|
| 24 | +from pymc import sample_posterior_predictive, set_data |
22 | 25 | from pymc.distributions.transforms import logodds
|
23 | 26 | from pymc.model.transform.conditioning import (
|
24 | 27 | change_value_transforms,
|
@@ -253,6 +256,46 @@ def test_do_self_reference():
|
253 | 256 | np.testing.assert_allclose(draw_x + 100, draw_do_x)
|
254 | 257 |
|
255 | 258 |
|
| 259 | +def test_do_make_intervenstions_shared(): |
| 260 | + with pm.Model(coords={"obs": [1]}) as m: |
| 261 | + x = pm.Normal("x", dims="obs") |
| 262 | + y = pm.Normal("y", dims="obs") |
| 263 | + |
| 264 | + constant_m = do(m, {x: [0.5]}, make_interventions_shared=False) |
| 265 | + constant_x = constant_m["x"] |
| 266 | + assert isinstance(constant_x, Constant) |
| 267 | + np.testing.assert_array_equal(constant_x.data, [0.5]) |
| 268 | + |
| 269 | + shared_m = do(m, {x: [0.5]}, make_interventions_shared=True) |
| 270 | + shared_x = shared_m["x"] |
| 271 | + assert isinstance(shared_x, SharedVariable) |
| 272 | + np.testing.assert_array_equal(shared_x.get_value(borrow=True), [0.5]) |
| 273 | + |
| 274 | + with shared_m: |
| 275 | + set_data({"x": [0.6, 0.9]}, coords={"obs": [2, 3]}) |
| 276 | + pp_y = pm.sample_prior_predictive(draws=3).prior["y"] |
| 277 | + assert pp_y.sizes == {"chain": 1, "draw": 3, "obs": 2} |
| 278 | + assert pp_y.shape == (1, 3, 2) |
| 279 | + |
| 280 | + |
| 281 | +@pytest.mark.parametrize( |
| 282 | + "make_interventions_shared", |
| 283 | + [True, pytest.param(False, marks=pytest.mark.xfail(reason="#6876"))], |
| 284 | +) |
| 285 | +def test_do_sample_posterior_predictive(make_interventions_shared): |
| 286 | + # Regression test for https://github.com/pymc-devs/pymc/issues/6977 |
| 287 | + with pm.Model() as model: |
| 288 | + a = pm.Normal("a") |
| 289 | + b = pm.Deterministic("b", a * 2) |
| 290 | + c = pm.Normal("c", b / 2) |
| 291 | + |
| 292 | + idata = az.from_dict({"a": [[1.0]], "b": [[2.0]], "c": [[1.0]]}) |
| 293 | + |
| 294 | + with do(model, {a: 1000}, make_interventions_shared=make_interventions_shared): |
| 295 | + pp = sample_posterior_predictive(idata, var_names=["c"], predictions=True).predictions |
| 296 | + assert (pp["c"] > 500).all() |
| 297 | + |
| 298 | + |
256 | 299 | def test_change_value_transforms():
|
257 | 300 | with pm.Model() as base_m:
|
258 | 301 | p = pm.Uniform("p", 0, 1, default_transform=None)
|
|
0 commit comments