Skip to content

Commit 2eaa0e3

Browse files
ricardoV94markgreene74Hemant19870601
committed
Use more specialized imports in test_custom
Co-authored-by: Giuseppe Cunsolo <[email protected]> Co-authored-by: [email protected] Co-authored-by: "Hemant19870601" <[email protected]>
1 parent 6108eb4 commit 2eaa0e3

File tree

1 file changed

+66
-61
lines changed

1 file changed

+66
-61
lines changed

tests/distributions/test_custom.py

Lines changed: 66 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,34 @@
2323
from pytensor import tensor as pt
2424
from scipy import stats as st
2525

26-
import pymc as pm
27-
28-
from pymc import (
29-
CustomDist,
30-
Deterministic,
26+
from pymc.distributions import (
27+
Bernoulli,
28+
Beta,
29+
Categorical,
30+
ChiSquared,
3131
DiracDelta,
32+
Flat,
3233
HalfNormal,
3334
LogNormal,
34-
Model,
35+
Mixture,
36+
MvNormal,
3537
Normal,
36-
draw,
37-
logcdf,
38-
logp,
39-
sample,
38+
NormalMixture,
39+
RandomWalk,
40+
StudentT,
41+
Truncated,
42+
Uniform,
4043
)
41-
from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV
44+
from pymc.distributions.custom import CustomDist, CustomDistRV, CustomSymbolicDistRV
4245
from pymc.distributions.distribution import support_point
4346
from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
4447
from pymc.distributions.transforms import log
4548
from pymc.exceptions import BlockModelAccessError
49+
from pymc.logprob import logcdf, logp
50+
from pymc.model import Deterministic, Model
4651
from pymc.pytensorf import collect_default_updates
52+
from pymc.sampling import draw, sample, sample_posterior_predictive
53+
from pymc.step_methods import Metropolis
4754
from pymc.testing import assert_support_point_is_expected
4855

4956

@@ -88,15 +95,15 @@ def test_custom_dist_without_random(self):
8895
custom_dist = CustomDist(
8996
"custom_dist",
9097
mu,
91-
logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value),
98+
logp=lambda value, mu: logp(Normal.dist(mu, 1, size=100), value),
9299
observed=np.random.randn(100),
93100
initval=0,
94101
)
95102
assert isinstance(custom_dist.owner.op, CustomDistRV)
96-
idata = sample(tune=50, draws=100, cores=1, step=pm.Metropolis())
103+
idata = sample(tune=50, draws=100, cores=1, step=Metropolis())
97104

98105
with pytest.raises(NotImplementedError):
99-
pm.sample_posterior_predictive(idata, model=model)
106+
sample_posterior_predictive(idata, model=model)
100107

101108
@pytest.mark.xfail(
102109
NotImplementedError,
@@ -159,7 +166,7 @@ def test_custom_dist_multivariate_logp(self, size):
159166
with Model() as model:
160167

161168
def logp(value, mu):
162-
return pm.MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
169+
return MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
163170

164171
mu = Normal("mu", size=supp_shape)
165172
a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
@@ -184,14 +191,14 @@ def logp(value, mu):
184191
def test_custom_dist_default_support_point_univariate(self, support_point, size, expected):
185192
if support_point == "custom_support_point":
186193
support_point = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype) # noqa E731
187-
with pm.Model() as model:
194+
with Model() as model:
188195
x = CustomDist("x", support_point=support_point, size=size)
189196
assert isinstance(x.owner.op, CustomDistRV)
190197
assert_support_point_is_expected(model, expected, check_finite_logp=False)
191198

192199
def test_custom_dist_moment_future_warning(self):
193200
moment = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype) # noqa E731
194-
with pm.Model() as model:
201+
with Model() as model:
195202
with pytest.warns(
196203
FutureWarning, match="`moment` argument is deprecated. Use `support_point` instead."
197204
):
@@ -280,24 +287,24 @@ def test_dist(self):
280287
mu = 1
281288
x = CustomDist.dist(
282289
mu,
283-
logp=lambda value, mu: pm.logp(pm.Normal.dist(mu), value),
290+
logp=lambda value, mu: logp(Normal.dist(mu), value),
284291
random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
285292
shape=(3,),
286293
)
287294

288295
x = cloudpickle.loads(cloudpickle.dumps(x))
289296

290-
test_value = pm.draw(x, random_seed=1)
291-
assert np.all(test_value == pm.draw(x, random_seed=1))
297+
test_value = draw(x, random_seed=1)
298+
assert np.all(test_value == draw(x, random_seed=1))
292299

293-
x_logp = pm.logp(x, test_value)
300+
x_logp = logp(x, test_value)
294301
assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))
295302

296303

297304
class TestCustomSymbolicDist:
298305
def test_basic(self):
299306
def custom_dist(mu, sigma, size):
300-
return pt.exp(pm.Normal.dist(mu, sigma, size=size))
307+
return pt.exp(Normal.dist(mu, sigma, size=size))
301308

302309
with Model() as m:
303310
mu = Normal("mu")
@@ -315,7 +322,7 @@ def custom_dist(mu, sigma, size):
315322
assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)
316323

317324
# Fix mu and sigma, so that all source of randomness comes from the symbolic RV
318-
draws = pm.draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0})
325+
draws = draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0})
319326
assert draws.shape == (3, 10)
320327
assert np.unique(draws).size == 30
321328

@@ -334,31 +341,31 @@ def custom_dist(mu, sigma, size):
334341
(5, 1),
335342
None,
336343
np.exp(5),
337-
lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)),
344+
lambda mu, sigma, size: pt.exp(Normal.dist(mu, sigma, size=size)),
338345
),
339346
(
340347
(2, np.ones(5)),
341348
None,
342349
np.exp(2 + np.ones(5)),
343-
lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size) + 1.0),
350+
lambda mu, sigma, size: pt.exp(Normal.dist(mu, sigma, size=size) + 1.0),
344351
),
345352
(
346353
(1, 2),
347354
None,
348355
np.sqrt(np.exp(1 + 0.5 * 2**2)),
349-
lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)),
356+
lambda mu, sigma, size: pt.sqrt(LogNormal.dist(mu, sigma, size=size)),
350357
),
351358
(
352359
(4,),
353360
(3,),
354361
np.log([4, 4, 4]),
355-
lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)),
362+
lambda nu, size: pt.log(ChiSquared.dist(nu, size=size)),
356363
),
357364
(
358365
(12, 1),
359366
None,
360367
12,
361-
lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size),
368+
lambda mu1, sigma, size: Normal.dist(mu1, sigma, size=size),
362369
),
363370
],
364371
)
@@ -369,7 +376,7 @@ def test_custom_dist_default_support_point(self, dist_params, size, expected, di
369376

370377
def test_custom_dist_default_support_point_scan(self):
371378
def scan_step(left, right):
372-
x = pm.Uniform.dist(left, right)
379+
x = Uniform.dist(left, right)
373380
x_update = collect_default_updates([x])
374381
return x, x_update
375382

@@ -390,7 +397,7 @@ def dist(size):
390397

391398
def test_custom_dist_default_support_point_scan_recurring(self):
392399
def scan_step(xtm1):
393-
x = pm.Normal.dist(xtm1 + 1)
400+
x = Normal.dist(xtm1 + 1)
394401
x_update = collect_default_updates([x])
395402
return x, x_update
396403

@@ -417,15 +424,15 @@ def dist(size):
417424
)
418425
def test_custom_dist_default_support_point_nested(self, left, right, size, expected):
419426
def dist_fn(left, right, size):
420-
return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5
427+
return Truncated.dist(Normal.dist(0, 1), left, right, size=size) + 5
421428

422429
with Model() as model:
423430
CustomDist("x", left, right, size=size, dist=dist_fn)
424431
assert_support_point_is_expected(model, expected)
425432

426433
def test_logcdf_inference(self):
427434
def custom_dist(mu, sigma, size):
428-
return pt.exp(pm.Normal.dist(mu, sigma, size=size))
435+
return pt.exp(Normal.dist(mu, sigma, size=size))
429436

430437
mu = 1
431438
sigma = 1.25
@@ -435,16 +442,16 @@ def custom_dist(mu, sigma, size):
435442
ref_lognormal = LogNormal.dist(mu, sigma)
436443

437444
np.testing.assert_allclose(
438-
pm.logcdf(custom_lognormal, test_value).eval(),
439-
pm.logcdf(ref_lognormal, test_value).eval(),
445+
logcdf(custom_lognormal, test_value).eval(),
446+
logcdf(ref_lognormal, test_value).eval(),
440447
)
441448

442449
def test_random_multiple_rngs(self):
443450
def custom_dist(p, sigma, size):
444-
idx = pm.Bernoulli.dist(p=p)
451+
idx = Bernoulli.dist(p=p)
445452
if rv_size_is_none(size):
446453
size = pt.broadcast_shape(p, sigma)
447-
comps = pm.Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
454+
comps = Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
448455
return comps[idx]
449456

450457
customdist = CustomDist.dist(
@@ -461,7 +468,7 @@ def custom_dist(p, sigma, size):
461468
assert len(node.outputs) == 3 # RV and 2 updated RNGs
462469
assert len(node.op.update(node)) == 2
463470

464-
draws = pm.draw(customdist, draws=2, random_seed=123)
471+
draws = draw(customdist, draws=2, random_seed=123)
465472
assert np.unique(draws).size == 20
466473

467474
def test_custom_methods(self):
@@ -494,7 +501,7 @@ def custom_logcdf(value, mu):
494501

495502
def test_change_size(self):
496503
def custom_dist(mu, sigma, size):
497-
return pt.exp(pm.Normal.dist(mu, sigma, size=size))
504+
return pt.exp(Normal.dist(mu, sigma, size=size))
498505

499506
lognormal = CustomDist.dist(
500507
0,
@@ -515,9 +522,9 @@ def custom_dist(mu, sigma, size):
515522

516523
def test_error_model_access(self):
517524
def custom_dist(size):
518-
return pm.Flat("Flat", size=size)
525+
return Flat("Flat", size=size)
519526

520-
with pm.Model() as m:
527+
with Model() as m:
521528
with pytest.raises(
522529
BlockModelAccessError,
523530
match="Model variables cannot be created in the dist function",
@@ -526,7 +533,7 @@ def custom_dist(size):
526533

527534
def test_api_change_error(self):
528535
def old_random(size):
529-
return pm.Flat.dist(size=size)
536+
return Flat.dist(size=size)
530537

531538
# Old API raises
532539
with pytest.raises(TypeError, match="API change: function passed to `random` argument"):
@@ -541,7 +548,7 @@ def trw(nu, sigma, steps, size):
541548
size = ()
542549

543550
def step(xtm1, nu, sigma):
544-
x = pm.StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
551+
x = StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
545552
return x, collect_default_updates([x])
546553

547554
xs, _ = scan(
@@ -562,52 +569,50 @@ def step(xtm1, nu, sigma):
562569
batch_size = 3
563570
x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size)
564571

565-
x_draw = pm.draw(x, random_seed=1)
572+
x_draw = draw(x, random_seed=1)
566573
assert x_draw.shape == (steps, batch_size)
567-
np.testing.assert_allclose(pm.draw(x, random_seed=1), x_draw)
568-
assert not np.any(pm.draw(x, random_seed=2) == x_draw)
574+
np.testing.assert_allclose(draw(x, random_seed=1), x_draw)
575+
assert not np.any(draw(x, random_seed=2) == x_draw)
569576

570-
ref_dist = pm.RandomWalk.dist(
571-
init_dist=pm.Flat.dist(),
572-
innovation_dist=pm.StudentT.dist(nu=nu, sigma=sigma),
577+
ref_dist = RandomWalk.dist(
578+
init_dist=Flat.dist(),
579+
innovation_dist=StudentT.dist(nu=nu, sigma=sigma),
573580
steps=steps,
574581
size=(batch_size,),
575582
)
576583
ref_val = pt.concatenate([np.zeros((1, batch_size)), x_draw]).T
577584

578585
np.testing.assert_allclose(
579-
pm.logp(x, x_draw).eval().sum(0),
580-
pm.logp(ref_dist, ref_val).eval(),
586+
logp(x, x_draw).eval().sum(0),
587+
logp(ref_dist, ref_val).eval(),
581588
)
582589

583590
def test_inferred_logp_mixture(self):
584591
import numpy as np
585592

586-
import pymc as pm
587-
588593
def shifted_normal(mu, sigma, size):
589-
return mu + pm.Normal.dist(0, sigma, shape=size)
594+
return mu + Normal.dist(0, sigma, shape=size)
590595

591596
mus = [3.5, -4.3]
592597
sds = [1.5, 2.3]
593598
w = [0.3, 0.7]
594-
with pm.Model() as m:
599+
with Model() as m:
595600
comp_dists = [
596601
CustomDist.dist(mus[0], sds[0], dist=shifted_normal),
597602
CustomDist.dist(mus[1], sds[1], dist=shifted_normal),
598603
]
599-
pm.Mixture("mix", w=w, comp_dists=comp_dists)
604+
Mixture("mix", w=w, comp_dists=comp_dists)
600605

601606
test_value = 0.1
602607
np.testing.assert_allclose(
603608
m.compile_logp()({"mix": test_value}),
604-
pm.logp(pm.NormalMixture.dist(w=w, mu=mus, sigma=sds), test_value).eval(),
609+
logp(NormalMixture.dist(w=w, mu=mus, sigma=sds), test_value).eval(),
605610
)
606611

607612
def test_symbolic_dist(self):
608613
# Test we can create a SymbolicDist inside a CustomDist
609614
def dist(size):
610-
return pm.Truncated.dist(pm.Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)
615+
return Truncated.dist(Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)
611616

612617
assert CustomDist.dist(dist=dist)
613618

@@ -616,20 +621,20 @@ def test_nested_custom_dist(self):
616621

617622
def dist(size=None):
618623
def inner_dist(size=None):
619-
return pm.Normal.dist(size=size)
624+
return Normal.dist(size=size)
620625

621626
inner_dist = CustomDist.dist(dist=inner_dist, size=size)
622627
return pt.exp(inner_dist)
623628

624629
rv = CustomDist.dist(dist=dist)
625630
np.testing.assert_allclose(
626-
pm.logp(rv, 1.0).eval(),
627-
pm.logp(pm.LogNormal.dist(), 1.0).eval(),
631+
logp(rv, 1.0).eval(),
632+
logp(LogNormal.dist(), 1.0).eval(),
628633
)
629634

630635
def test_signature(self):
631636
def dist(p, size):
632-
return -pm.Categorical.dist(p=p, size=size)
637+
return -Categorical.dist(p=p, size=size)
633638

634639
out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(p)->()")
635640
# Size and updates are added automatically to the signature

0 commit comments

Comments
 (0)