 from pytensor import tensor as pt
 from scipy import stats as st

-import pymc as pm
-
-from pymc import (
-    CustomDist,
-    Deterministic,
+from pymc.distributions import (
+    Bernoulli,
+    Beta,
+    Categorical,
+    ChiSquared,
     DiracDelta,
+    Flat,
     HalfNormal,
     LogNormal,
-    Model,
+    Mixture,
+    MvNormal,
     Normal,
-    draw,
-    logcdf,
-    logp,
-    sample,
+    NormalMixture,
+    RandomWalk,
+    StudentT,
+    Truncated,
+    Uniform,
 )
-from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV
+from pymc.distributions.custom import CustomDist, CustomDistRV, CustomSymbolicDistRV
 from pymc.distributions.distribution import support_point
 from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
 from pymc.distributions.transforms import log
 from pymc.exceptions import BlockModelAccessError
+from pymc.logprob import logcdf, logp
+from pymc.model import Deterministic, Model
 from pymc.pytensorf import collect_default_updates
+from pymc.sampling import draw, sample, sample_posterior_predictive
+from pymc.step_methods import Metropolis
 from pymc.testing import assert_support_point_is_expected


@@ -88,15 +95,15 @@ def test_custom_dist_without_random(self):
             custom_dist = CustomDist(
                 "custom_dist",
                 mu,
-                logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value),
+                logp=lambda value, mu: logp(Normal.dist(mu, 1, size=100), value),
                 observed=np.random.randn(100),
                 initval=0,
             )
             assert isinstance(custom_dist.owner.op, CustomDistRV)
-            idata = sample(tune=50, draws=100, cores=1, step=pm.Metropolis())
+            idata = sample(tune=50, draws=100, cores=1, step=Metropolis())

         with pytest.raises(NotImplementedError):
-            pm.sample_posterior_predictive(idata, model=model)
+            sample_posterior_predictive(idata, model=model)

     @pytest.mark.xfail(
         NotImplementedError,
@@ -159,7 +166,7 @@ def test_custom_dist_multivariate_logp(self, size):
         with Model() as model:

             def logp(value, mu):
-                return pm.MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
+                return MvNormal.logp(value, mu, pt.eye(mu.shape[0]))

             mu = Normal("mu", size=supp_shape)
             a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
@@ -184,14 +191,14 @@ def logp(value, mu):
     def test_custom_dist_default_support_point_univariate(self, support_point, size, expected):
         if support_point == "custom_support_point":
             support_point = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype)  # noqa E731
-        with pm.Model() as model:
+        with Model() as model:
             x = CustomDist("x", support_point=support_point, size=size)
         assert isinstance(x.owner.op, CustomDistRV)
         assert_support_point_is_expected(model, expected, check_finite_logp=False)

     def test_custom_dist_moment_future_warning(self):
         moment = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype)  # noqa E731
-        with pm.Model() as model:
+        with Model() as model:
             with pytest.warns(
                 FutureWarning, match="`moment` argument is deprecated. Use `support_point` instead."
             ):
@@ -280,24 +287,24 @@ def test_dist(self):
         mu = 1
         x = CustomDist.dist(
             mu,
-            logp=lambda value, mu: pm.logp(pm.Normal.dist(mu), value),
+            logp=lambda value, mu: logp(Normal.dist(mu), value),
             random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
             shape=(3,),
         )

         x = cloudpickle.loads(cloudpickle.dumps(x))

-        test_value = pm.draw(x, random_seed=1)
-        assert np.all(test_value == pm.draw(x, random_seed=1))
+        test_value = draw(x, random_seed=1)
+        assert np.all(test_value == draw(x, random_seed=1))

-        x_logp = pm.logp(x, test_value)
+        x_logp = logp(x, test_value)
         assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))


 class TestCustomSymbolicDist:
     def test_basic(self):
         def custom_dist(mu, sigma, size):
-            return pt.exp(pm.Normal.dist(mu, sigma, size=size))
+            return pt.exp(Normal.dist(mu, sigma, size=size))

         with Model() as m:
             mu = Normal("mu")
@@ -315,7 +322,7 @@ def custom_dist(mu, sigma, size):
         assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)

         # Fix mu and sigma, so that all source of randomness comes from the symbolic RV
-        draws = pm.draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0})
+        draws = draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0})
         assert draws.shape == (3, 10)
         assert np.unique(draws).size == 30

@@ -334,31 +341,31 @@ def custom_dist(mu, sigma, size):
                 (5, 1),
                 None,
                 np.exp(5),
-                lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)),
+                lambda mu, sigma, size: pt.exp(Normal.dist(mu, sigma, size=size)),
             ),
             (
                 (2, np.ones(5)),
                 None,
                 np.exp(2 + np.ones(5)),
-                lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size) + 1.0),
+                lambda mu, sigma, size: pt.exp(Normal.dist(mu, sigma, size=size) + 1.0),
             ),
             (
                 (1, 2),
                 None,
                 np.sqrt(np.exp(1 + 0.5 * 2**2)),
-                lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)),
+                lambda mu, sigma, size: pt.sqrt(LogNormal.dist(mu, sigma, size=size)),
             ),
             (
                 (4,),
                 (3,),
                 np.log([4, 4, 4]),
-                lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)),
+                lambda nu, size: pt.log(ChiSquared.dist(nu, size=size)),
             ),
             (
                 (12, 1),
                 None,
                 12,
-                lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size),
+                lambda mu1, sigma, size: Normal.dist(mu1, sigma, size=size),
             ),
         ],
     )
@@ -369,7 +376,7 @@ def test_custom_dist_default_support_point(self, dist_params, size, expected, di

     def test_custom_dist_default_support_point_scan(self):
         def scan_step(left, right):
-            x = pm.Uniform.dist(left, right)
+            x = Uniform.dist(left, right)
             x_update = collect_default_updates([x])
             return x, x_update

@@ -390,7 +397,7 @@ def dist(size):

     def test_custom_dist_default_support_point_scan_recurring(self):
         def scan_step(xtm1):
-            x = pm.Normal.dist(xtm1 + 1)
+            x = Normal.dist(xtm1 + 1)
             x_update = collect_default_updates([x])
             return x, x_update

@@ -417,15 +424,15 @@ def dist(size):
     )
     def test_custom_dist_default_support_point_nested(self, left, right, size, expected):
         def dist_fn(left, right, size):
-            return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5
+            return Truncated.dist(Normal.dist(0, 1), left, right, size=size) + 5

         with Model() as model:
             CustomDist("x", left, right, size=size, dist=dist_fn)
         assert_support_point_is_expected(model, expected)

     def test_logcdf_inference(self):
         def custom_dist(mu, sigma, size):
-            return pt.exp(pm.Normal.dist(mu, sigma, size=size))
+            return pt.exp(Normal.dist(mu, sigma, size=size))

         mu = 1
         sigma = 1.25
@@ -435,16 +442,16 @@ def custom_dist(mu, sigma, size):
         ref_lognormal = LogNormal.dist(mu, sigma)

         np.testing.assert_allclose(
-            pm.logcdf(custom_lognormal, test_value).eval(),
-            pm.logcdf(ref_lognormal, test_value).eval(),
+            logcdf(custom_lognormal, test_value).eval(),
+            logcdf(ref_lognormal, test_value).eval(),
         )

     def test_random_multiple_rngs(self):
         def custom_dist(p, sigma, size):
-            idx = pm.Bernoulli.dist(p=p)
+            idx = Bernoulli.dist(p=p)
             if rv_size_is_none(size):
                 size = pt.broadcast_shape(p, sigma)
-            comps = pm.Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
+            comps = Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
             return comps[idx]

         customdist = CustomDist.dist(
@@ -461,7 +468,7 @@ def custom_dist(p, sigma, size):
         assert len(node.outputs) == 3  # RV and 2 updated RNGs
         assert len(node.op.update(node)) == 2

-        draws = pm.draw(customdist, draws=2, random_seed=123)
+        draws = draw(customdist, draws=2, random_seed=123)
         assert np.unique(draws).size == 20

     def test_custom_methods(self):
@@ -494,7 +501,7 @@ def custom_logcdf(value, mu):

     def test_change_size(self):
         def custom_dist(mu, sigma, size):
-            return pt.exp(pm.Normal.dist(mu, sigma, size=size))
+            return pt.exp(Normal.dist(mu, sigma, size=size))

         lognormal = CustomDist.dist(
             0,
@@ -515,9 +522,9 @@ def custom_dist(mu, sigma, size):

     def test_error_model_access(self):
         def custom_dist(size):
-            return pm.Flat("Flat", size=size)
+            return Flat("Flat", size=size)

-        with pm.Model() as m:
+        with Model() as m:
             with pytest.raises(
                 BlockModelAccessError,
                 match="Model variables cannot be created in the dist function",
@@ -526,7 +533,7 @@ def custom_dist(size):

     def test_api_change_error(self):
         def old_random(size):
-            return pm.Flat.dist(size=size)
+            return Flat.dist(size=size)

         # Old API raises
         with pytest.raises(TypeError, match="API change: function passed to `random` argument"):
@@ -541,7 +548,7 @@ def trw(nu, sigma, steps, size):
                 size = ()

             def step(xtm1, nu, sigma):
-                x = pm.StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
+                x = StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
                 return x, collect_default_updates([x])

             xs, _ = scan(
@@ -562,52 +569,50 @@ def step(xtm1, nu, sigma):
         batch_size = 3
         x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size)

-        x_draw = pm.draw(x, random_seed=1)
+        x_draw = draw(x, random_seed=1)
         assert x_draw.shape == (steps, batch_size)
-        np.testing.assert_allclose(pm.draw(x, random_seed=1), x_draw)
-        assert not np.any(pm.draw(x, random_seed=2) == x_draw)
+        np.testing.assert_allclose(draw(x, random_seed=1), x_draw)
+        assert not np.any(draw(x, random_seed=2) == x_draw)

-        ref_dist = pm.RandomWalk.dist(
-            init_dist=pm.Flat.dist(),
-            innovation_dist=pm.StudentT.dist(nu=nu, sigma=sigma),
+        ref_dist = RandomWalk.dist(
+            init_dist=Flat.dist(),
+            innovation_dist=StudentT.dist(nu=nu, sigma=sigma),
             steps=steps,
             size=(batch_size,),
         )
         ref_val = pt.concatenate([np.zeros((1, batch_size)), x_draw]).T

         np.testing.assert_allclose(
-            pm.logp(x, x_draw).eval().sum(0),
-            pm.logp(ref_dist, ref_val).eval(),
+            logp(x, x_draw).eval().sum(0),
+            logp(ref_dist, ref_val).eval(),
         )

     def test_inferred_logp_mixture(self):
         import numpy as np

-        import pymc as pm
-
         def shifted_normal(mu, sigma, size):
-            return mu + pm.Normal.dist(0, sigma, shape=size)
+            return mu + Normal.dist(0, sigma, shape=size)

         mus = [3.5, -4.3]
         sds = [1.5, 2.3]
         w = [0.3, 0.7]
-        with pm.Model() as m:
+        with Model() as m:
             comp_dists = [
                 CustomDist.dist(mus[0], sds[0], dist=shifted_normal),
                 CustomDist.dist(mus[1], sds[1], dist=shifted_normal),
             ]
-            pm.Mixture("mix", w=w, comp_dists=comp_dists)
+            Mixture("mix", w=w, comp_dists=comp_dists)

         test_value = 0.1
         np.testing.assert_allclose(
             m.compile_logp()({"mix": test_value}),
-            pm.logp(pm.NormalMixture.dist(w=w, mu=mus, sigma=sds), test_value).eval(),
+            logp(NormalMixture.dist(w=w, mu=mus, sigma=sds), test_value).eval(),
         )

     def test_symbolic_dist(self):
         # Test we can create a SymbolicDist inside a CustomDist
         def dist(size):
-            return pm.Truncated.dist(pm.Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)
+            return Truncated.dist(Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)

         assert CustomDist.dist(dist=dist)

@@ -616,20 +621,20 @@ def test_nested_custom_dist(self):

         def dist(size=None):
             def inner_dist(size=None):
-                return pm.Normal.dist(size=size)
+                return Normal.dist(size=size)

             inner_dist = CustomDist.dist(dist=inner_dist, size=size)
             return pt.exp(inner_dist)

         rv = CustomDist.dist(dist=dist)
         np.testing.assert_allclose(
-            pm.logp(rv, 1.0).eval(),
-            pm.logp(pm.LogNormal.dist(), 1.0).eval(),
+            logp(rv, 1.0).eval(),
+            logp(LogNormal.dist(), 1.0).eval(),
         )

     def test_signature(self):
         def dist(p, size):
-            return -pm.Categorical.dist(p=p, size=size)
+            return -Categorical.dist(p=p, size=size)

         out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(p)->()")
         # Size and updates are added automatically to the signature
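
For context, the diff above replaces the `import pymc as pm` namespace access with direct imports from PyMC submodules (`pymc.distributions`, `pymc.logprob`, `pymc.model`, `pymc.sampling`, `pymc.step_methods`). The sketch below is illustrative only (not part of the test file) and uses the same CustomDist pattern exercised by these tests, with only names that appear in the new import block; `shifted_normal` is a hypothetical helper.

# Minimal sketch of the direct-import style adopted in this diff (illustrative, not from the diff).
from pymc.distributions import Normal
from pymc.distributions.custom import CustomDist
from pymc.logprob import logp
from pymc.sampling import draw


def shifted_normal(mu, size):
    # `dist` callable: build the generative graph from distribution primitives.
    return mu + Normal.dist(0, 1, size=size)


# Standalone RV; no Model context is needed for the .dist variant.
x = CustomDist.dist(1.0, dist=shifted_normal, size=(10,))

samples = draw(x, random_seed=1)   # sampling helper now imported from pymc.sampling
x_logp = logp(x, samples).eval()   # log-density inferred from the `dist` graph, via pymc.logprob
assert samples.shape == (10,) and x_logp.shape == (10,)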