Skip to content

Commit 3ccff92

Browse files
authored
Expand logging test cases for sample_prior_predictive and add return type overloads (#7707)
1 parent 07fc908 commit 3ccff92

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

pymc/sampling/forward.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from collections.abc import Callable, Iterable, Sequence
2121
from typing import (
2222
Any,
23+
Literal,
2324
TypeAlias,
2425
cast,
26+
overload,
2527
)
2628

2729
import numpy as np
@@ -360,6 +362,28 @@ def observed_dependent_deterministics(model: Model, extra_observeds=None):
360362
]
361363

362364

365+
@overload
366+
def sample_prior_predictive(
367+
draws: int = 500,
368+
model: Model | None = None,
369+
var_names: Iterable[str] | None = None,
370+
random_seed: RandomState = None,
371+
return_inferencedata: Literal[True] = True,
372+
idata_kwargs: dict | None = None,
373+
compile_kwargs: dict | None = None,
374+
samples: int | None = None,
375+
) -> InferenceData: ...
376+
@overload
377+
def sample_prior_predictive(
378+
draws: int = 500,
379+
model: Model | None = None,
380+
var_names: Iterable[str] | None = None,
381+
random_seed: RandomState = None,
382+
return_inferencedata: Literal[False] = False,
383+
idata_kwargs: dict | None = None,
384+
compile_kwargs: dict | None = None,
385+
samples: int | None = None,
386+
) -> dict[str, np.ndarray]: ...
363387
def sample_prior_predictive(
364388
draws: int = 500,
365389
model: Model | None = None,

tests/sampling/test_forward.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,16 +857,24 @@ def test_logging_sampled_basic_rvs_prior(self, caplog):
857857
y = pm.Deterministic("y", x + 1)
858858
z = pm.Normal("z", y, observed=0)
859859

860+
# all volatile RVs in model
860861
with m:
861862
pm.sample_prior_predictive(draws=1)
862863
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")]
863864
caplog.clear()
864865

866+
# `x` has no dependencies so will be sampled by itself
865867
with m:
866868
pm.sample_prior_predictive(draws=1, var_names=["x"])
867869
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")]
868870
caplog.clear()
869871

872+
# `z` depends on `x`
873+
with m:
874+
pm.sample_prior_predictive(draws=1, var_names=["z"])
875+
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")]
876+
caplog.clear()
877+
870878
def test_logging_sampled_basic_rvs_posterior(self, caplog):
871879
with pm.Model() as m:
872880
x = pm.Normal("x")

0 commit comments

Comments
 (0)