
Commit 397e6f9

Standardize draws as parameter in sample_prior_predictive (#7366)
Co-authored-by: Ricardo Vieira <[email protected]>
Parent: 5bc6801
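In practice, this commit renames the `samples` keyword of `pm.sample_prior_predictive` to `draws`; the old keyword is kept as a deprecated alias that forwards to `draws` and emits a `DeprecationWarning`. A minimal sketch of the new call, using a toy model that is illustrative rather than taken from the diff:

    import pymc as pm

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("obs", mu=mu, sigma=1.0)

        # New spelling introduced by this commit
        idata = pm.sample_prior_predictive(draws=50, random_seed=42)

    # The old spelling still runs for now, but emits a DeprecationWarning:
    #   pm.sample_prior_predictive(samples=50)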

8 files changed: 53 additions & 32 deletions


docs/source/learn/core_notebooks/posterior_predictive.ipynb

Lines changed: 2 additions & 2 deletions
@@ -156,7 +156,7 @@
   " sigma = pm.Exponential(\"sigma\", 1.0)\n",
   "\n",
   " pm.Normal(\"obs\", mu=mu, sigma=sigma, observed=outcome_scaled)\n",
-  " idata = pm.sample_prior_predictive(samples=50, random_seed=rng)"
+  " idata = pm.sample_prior_predictive(draws=50, random_seed=rng)"
   ]
  },
  {
@@ -225,7 +225,7 @@
   " sigma = pm.Exponential(\"sigma\", 1.0)\n",
   "\n",
   " pm.Normal(\"obs\", mu=mu, sigma=sigma, observed=outcome_scaled)\n",
-  " idata = pm.sample_prior_predictive(samples=50, random_seed=rng)"
+  " idata = pm.sample_prior_predictive(draws=50, random_seed=rng)"
   ]
  },
  {

pymc/sampling/forward.py

Lines changed: 15 additions & 3 deletions
@@ -338,19 +338,20 @@ def observed_dependent_deterministics(model: Model):
 
 
 def sample_prior_predictive(
-    samples: int = 500,
+    draws: int = 500,
     model: Model | None = None,
     var_names: Iterable[str] | None = None,
     random_seed: RandomState = None,
     return_inferencedata: bool = True,
     idata_kwargs: dict | None = None,
     compile_kwargs: dict | None = None,
+    samples: int | None = None,
 ) -> InferenceData | dict[str, np.ndarray]:
     """Generate samples from the prior predictive distribution.
 
     Parameters
     ----------
-    samples : int
+    draws : int
         Number of samples from the prior predictive to generate. Defaults to 500.
     model : Model (optional if in ``with`` context)
     var_names : Iterable[str]
@@ -366,13 +367,24 @@ def sample_prior_predictive(
         Keyword arguments for :func:`pymc.to_inference_data`
     compile_kwargs: dict, optional
         Keyword arguments for :func:`pymc.pytensorf.compile_pymc`.
+    samples : int
+        Number of samples from the prior predictive to generate. Deprecated in favor of `draws`.
 
     Returns
     -------
     arviz.InferenceData or Dict
         An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default),
         or a dictionary with variable names as keys and samples as numpy arrays.
     """
+    if samples is not None:
+        warnings.warn(
+            f"The samples argument has been deprecated in favor of draws. Use draws={samples} going forward.",
+            DeprecationWarning,
+            stacklevel=1,
+        )
+
+        draws = samples
+
     model = modelcontext(model)
 
     if model.potentials:
@@ -415,7 +427,7 @@ def sample_prior_predictive(
 
     # All model variables have a name, but mypy does not know this
     _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
-    values = zip(*(sampler_fn() for i in range(samples)))
+    values = zip(*(sampler_fn() for i in range(draws)))
 
     data = {k: np.stack(v) for k, v in zip(names, values)}
     if data is None:
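The shim above only fires when `samples` is passed explicitly, so existing callers keep working while being warned. A hedged sketch of how a caller might check the transition behaviour (the model and assertions here are illustrative, not part of the diff):

    import warnings

    import pymc as pm

    with pm.Model():
        pm.Normal("x")

        # Deprecated keyword: still forwarded to draws, but warns.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            idata = pm.sample_prior_predictive(samples=10)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert idata.prior["x"].sizes["draw"] == 10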

tests/distributions/test_mixture.py

Lines changed: 5 additions & 5 deletions
@@ -561,7 +561,7 @@ def test_single_poisson_predictive_sampling_shape(self):
 
         n_samples = 30
         with model:
-            prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False)
+            prior = sample_prior_predictive(draws=n_samples, return_inferencedata=False)
             ppc = sample_posterior_predictive(
                 n_samples * [self.get_initial_point(model)], return_inferencedata=False
             )
@@ -607,7 +607,7 @@ def test_list_mvnormals_predictive_sampling_shape(self):
 
         n_samples = 20
         with model:
-            prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False)
+            prior = sample_prior_predictive(draws=n_samples, return_inferencedata=False)
             ppc = sample_posterior_predictive(
                 n_samples * [self.get_initial_point(model)], return_inferencedata=False
             )
@@ -1028,7 +1028,7 @@ def test_with_multinomial(self, seeded_test, batch_shape):
                 comp_dists=comp_dists,
                 shape=(*batch_shape, 3),
             )
-            prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
+            prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False)
 
         assert prior["mixture"].shape == (self.n_samples, *batch_shape, 3)
         assert draw(mixture, draws=self.size).shape == (self.size, *batch_shape, 3)
@@ -1060,7 +1060,7 @@ def test_with_mvnormal(self, seeded_test):
         with Model() as model:
             comp_dists = MvNormal.dist(mu=mu, chol=chol, shape=(self.mixture_comps, 3))
             mixture = Mixture("mixture", w=w, comp_dists=comp_dists, shape=(3,))
-            prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
+            prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False)
 
         assert prior["mixture"].shape == (self.n_samples, 3)
         assert draw(mixture, draws=self.size).shape == (self.size, 3)
@@ -1084,7 +1084,7 @@ def test_broadcasting_in_shape(self):
             mu = Gamma("mu", 1.0, 1.0, shape=2)
             comp_dists = Poisson.dist(mu, shape=2)
             mix = Mixture("mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,))
-            prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
+            prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False)
 
         assert prior["mix"].shape == (self.n_samples, 1000)
 

tests/distributions/test_multivariate.py

Lines changed: 3 additions & 3 deletions
@@ -1448,7 +1448,7 @@ def test_with_chol_rv(self):
                 "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
             )
             mv = pm.MvNormal("mv", mu, chol=chol, size=4)
-            prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)
+            prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)
 
         assert prior["mv"].shape == (10, 4, 3)
 
@@ -1462,7 +1462,7 @@ def test_with_cov_rv(
                 "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
             )
             mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), size=4)
-            prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)
+            prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)
 
         assert prior["mv"].shape == (10, 4, 3)
 
@@ -1473,7 +1473,7 @@ def test_with_lkjcorr_matrix(
             corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True)
             pm.Deterministic("corr_mat", corr)
             mv = pm.MvNormal("mv", 0.0, cov=corr, size=4)
-            prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)
+            prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)
 
         assert prior["corr_mat"].shape == (10, 3, 3)  # square
         assert (prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]] == 1.0).all()  # 1.0 on diagonal

tests/gp/test_hsgp_approx.py

Lines changed: 4 additions & 4 deletions
@@ -213,7 +213,7 @@ def test_prior(self, model, cov_func, X1, parametrization, rng):
             gp = pm.gp.Latent(cov_func=cov_func)
             f2 = gp.prior("f2", X=X1)
 
-            idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
+            idata = pm.sample_prior_predictive(draws=1000, random_seed=rng)
 
         samples1 = az.extract(idata.prior["f1"])["f1"].values.T
         samples2 = az.extract(idata.prior["f2"])["f2"].values.T
@@ -240,7 +240,7 @@ def test_conditional(self, model, cov_func, X1, parametrization):
             f = hsgp.prior("f", X=X1)
             fc = hsgp.conditional("fc", Xnew=X1)
 
-            idata = pm.sample_prior_predictive(samples=1000)
+            idata = pm.sample_prior_predictive(draws=1000)
 
         samples1 = az.extract(idata.prior["f"])["f"].values.T
         samples2 = az.extract(idata.prior["fc"])["fc"].values.T
@@ -300,7 +300,7 @@ def test_prior(self, model, cov_func, eta, X1, rng):
             gp = pm.gp.Latent(cov_func=eta**2 * cov_func)
             f2 = gp.prior("f2", X=X1)
 
-            idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
+            idata = pm.sample_prior_predictive(draws=1000, random_seed=rng)
 
         samples1 = az.extract(idata.prior["f1"])["f1"].values.T
         samples2 = az.extract(idata.prior["f2"])["f2"].values.T
@@ -321,7 +321,7 @@ def test_conditional_periodic(self, model, cov_func, X1):
             f = hsgp.prior("f", X=X1)
             fc = hsgp.conditional("fc", Xnew=X1)
 
-            idata = pm.sample_prior_predictive(samples=1000)
+            idata = pm.sample_prior_predictive(draws=1000)
 
         samples1 = az.extract(idata.prior["f"])["f"].values.T
         samples2 = az.extract(idata.prior["fc"])["fc"].values.T

tests/model/test_core.py

Lines changed: 1 addition & 1 deletion
@@ -873,7 +873,7 @@ def test_none_coords_autonumbering(self):
             m.add_coord(name="a", values=None, length=3)
             m.add_coord(name="b", values=range(5))
             x = pm.Normal("x", dims=("a", "b"))
-            prior = pm.sample_prior_predictive(samples=2).prior
+            prior = pm.sample_prior_predictive(draws=2).prior
         assert prior["x"].shape == (1, 2, 3, 5)
         assert list(prior.coords["a"].values) == list(range(3))
         assert list(prior.coords["b"].values) == list(range(5))

tests/sampling/test_deterministic.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def test_compute_deterministics():
         sigma = Deterministic("sigma", sigma_raw.exp())
 
     dataset = sample_prior_predictive(
-        samples=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22
+        draws=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22
     ).prior
 
     # Test default

tests/sampling/test_forward.py

Lines changed: 22 additions & 13 deletions
@@ -794,12 +794,12 @@ def test_logging_sampled_basic_rvs_prior(self, caplog):
             z = pm.Normal("z", y, observed=0)
 
         with m:
-            pm.sample_prior_predictive(samples=1)
+            pm.sample_prior_predictive(draws=1)
         assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")]
         caplog.clear()
 
         with m:
-            pm.sample_prior_predictive(samples=1, var_names=["x"])
+            pm.sample_prior_predictive(draws=1, var_names=["x"])
         assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")]
         caplog.clear()
 
@@ -1028,7 +1028,7 @@ def test_observed_data_needed_in_pp(self):
             mu = x_data.sum(-1)
             pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))
 
-        prior = pm.sample_prior_predictive(samples=25).prior
+        prior = pm.sample_prior_predictive(draws=25).prior
 
         fake_idata = InferenceData(posterior=prior)
 
@@ -1052,7 +1052,7 @@ def test_observed_data_needed_in_pp(self):
             mu = (y_data.sum() * x_data).sum(-1)
             pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))
 
-        prior = pm.sample_prior_predictive(samples=25).prior
+        prior = pm.sample_prior_predictive(draws=25).prior
 
         fake_idata = InferenceData(posterior=prior)
 
@@ -1135,7 +1135,7 @@ def test_multivariate2(self, seeded_test):
                 compute_convergence_checks=False,
             )
         sim_priors = pm.sample_prior_predictive(
-            return_inferencedata=False, samples=20, model=dm_model
+            return_inferencedata=False, draws=20, model=dm_model
         )
         sim_ppc = pm.sample_posterior_predictive(
             burned_trace, return_inferencedata=False, model=dm_model
@@ -1227,7 +1227,7 @@ def test_zeroinflatedpoisson(self):
             mu = pm.Beta("mu", alpha=1, beta=1)
             psi = pm.HalfNormal("psi", sigma=1)
             pm.ZeroInflatedPoisson("suppliers", psi=psi, mu=mu, size=20)
-            gen_data = pm.sample_prior_predictive(samples=5000)
+            gen_data = pm.sample_prior_predictive(draws=5000)
             assert gen_data.prior["mu"].shape == (1, 5000)
             assert gen_data.prior["psi"].shape == (1, 5000)
             assert gen_data.prior["suppliers"].shape == (1, 5000, 20)
@@ -1240,7 +1240,7 @@ def test_potentials_warning(self):
 
         with m:
             with pytest.warns(UserWarning, match=warning_msg):
-                pm.sample_prior_predictive(samples=5)
+                pm.sample_prior_predictive(draws=5)
 
     def test_transformed_vars_not_supported(self):
         with pm.Model() as model:
@@ -1260,7 +1260,7 @@ def test_issue_4490(self):
             c = pm.Normal("c")
             d = pm.Normal("d")
             prior1 = pm.sample_prior_predictive(
-                samples=1, var_names=["a", "b", "c", "d"], random_seed=seed
+                draws=1, var_names=["a", "b", "c", "d"], random_seed=seed
             )
 
         with pm.Model() as m2:
@@ -1269,7 +1269,7 @@ def test_issue_4490(self):
             c = pm.Normal("c")
             d = pm.Normal("d")
             prior2 = pm.sample_prior_predictive(
-                samples=1, var_names=["b", "a", "d", "c"], random_seed=seed
+                draws=1, var_names=["b", "a", "d", "c"], random_seed=seed
             )
 
         assert prior1.prior["a"] == prior2.prior["a"]
@@ -1284,7 +1284,7 @@ def test_pytensor_function_kwargs(self):
             y = pm.Deterministic("y", x + sharedvar)
 
             prior = pm.sample_prior_predictive(
-                samples=5,
+                draws=5,
                 return_inferencedata=False,
                 compile_kwargs=dict(
                     mode=Mode("py"),
@@ -1308,7 +1308,7 @@ def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
 
         with pmodel:
             prior = pm.sample_prior_predictive(
-                samples=20,
+                draws=20,
                 return_inferencedata=False,
             )
             idat = pm.to_inference_data(trace, prior=prior)
@@ -1367,7 +1367,7 @@ def test_distinct_rvs():
        Y_rv = pm.Normal("y")
 
        pp_samples = pm.sample_prior_predictive(
-            samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
+            draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
        )
 
     assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]
@@ -1377,7 +1377,7 @@ def test_distinct_rvs():
        Y_rv = pm.Normal("y")
 
        pp_samples_2 = pm.sample_prior_predictive(
-            samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
+            draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
        )
 
     assert np.array_equal(pp_samples["y"], pp_samples_2["y"])
@@ -1706,3 +1706,12 @@ def test_observed_dependent_deterministics():
         det_mixed = pm.Deterministic("det_mixed", free + obs)
 
     assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed}
+
+
+def test_sample_prior_predictive_samples_deprecated_warns() -> None:
+    with pm.Model() as m:
+        pm.Normal("a")
+
+    match = "The samples argument has been deprecated"
+    with pytest.warns(DeprecationWarning, match=match):
+        pm.sample_prior_predictive(model=m, samples=10)
