Skip to content

Commit 652c9de

Browse files
larryshamalamaricardoV94
authored andcommitted
Refactoring _print_name for certain RVs and specifying rv_type in their distributions
1 parent 7bad057 commit 652c9de

File tree

7 files changed

+158
-56
lines changed

7 files changed

+158
-56
lines changed

pymc/distributions/continuous.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,22 @@
3535
from aesara.tensor.math import tanh
3636
from aesara.tensor.random.basic import (
3737
BetaRV,
38-
cauchy,
38+
CauchyRV,
39+
HalfCauchyRV,
40+
HalfNormalRV,
41+
LogNormalRV,
42+
NormalRV,
43+
UniformRV,
3944
chisquare,
4045
exponential,
4146
gamma,
4247
gumbel,
43-
halfcauchy,
44-
halfnormal,
4548
invgamma,
4649
laplace,
4750
logistic,
48-
lognormal,
4951
normal,
5052
pareto,
5153
triangular,
52-
uniform,
5354
vonmises,
5455
)
5556
from aesara.tensor.random.op import RandomVariable
@@ -252,6 +253,13 @@ def get_tau_sigma(tau=None, sigma=None):
252253
return floatX(tau), floatX(sigma)
253254

254255

256+
class PyMCUniformRV(UniformRV):
257+
_print_name = ("Uniform", "\\operatorname{Uniform}")
258+
259+
260+
pymc_uniform = PyMCUniformRV()
261+
262+
255263
class Uniform(BoundedContinuous):
256264
r"""
257265
Continuous uniform log-likelihood.
@@ -295,7 +303,8 @@ class Uniform(BoundedContinuous):
295303
upper : tensor_like of float, default 1
296304
Upper limit.
297305
"""
298-
rv_op = uniform
306+
rv_op = pymc_uniform
307+
rv_type = UniformRV
299308
bound_args_indices = (3, 4) # Lower, Upper
300309

301310
@classmethod
@@ -479,6 +488,13 @@ def logcdf(value):
479488
return at.switch(at.lt(value, np.inf), -np.inf, at.switch(at.eq(value, np.inf), 0, -np.inf))
480489

481490

491+
class PyMCNormalRV(NormalRV):
492+
_print_name = ("Normal", "\\operatorname{Normal}")
493+
494+
495+
pymc_normal = PyMCNormalRV()
496+
497+
482498
class Normal(Continuous):
483499
r"""
484500
Univariate normal log-likelihood.
@@ -544,7 +560,8 @@ class Normal(Continuous):
544560
with pm.Model():
545561
x = pm.Normal('x', mu=0, tau=1/23)
546562
"""
547-
rv_op = normal
563+
rv_op = pymc_normal
564+
rv_type = NormalRV
548565

549566
@classmethod
550567
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
@@ -801,6 +818,13 @@ def truncated_normal_default_transform(op, rv):
801818
return bounded_cont_transform(op, rv, TruncatedNormal.bound_args_indices)
802819

803820

821+
class PyMCHalfNormalRV(HalfNormalRV):
822+
_print_name = ("HalfNormal", "\\operatorname{HalfNormal}")
823+
824+
825+
pymc_halfnormal = PyMCHalfNormalRV()
826+
827+
804828
class HalfNormal(PositiveContinuous):
805829
r"""
806830
Half-normal log-likelihood.
@@ -867,7 +891,8 @@ class HalfNormal(PositiveContinuous):
867891
with pm.Model():
868892
x = pm.HalfNormal('x', tau=1/15)
869893
"""
870-
rv_op = halfnormal
894+
rv_op = pymc_halfnormal
895+
rv_type = HalfNormalRV
871896

872897
@classmethod
873898
def dist(cls, sigma=None, tau=None, *args, **kwargs):
@@ -1690,6 +1715,13 @@ def logp(value, b, kappa, mu):
16901715
return check_parameters(res, 0 < b, 0 < kappa, msg="b > 0, kappa > 0")
16911716

16921717

1718+
class PyMCLogNormalRV(LogNormalRV):
1719+
_print_name = ("LogNormal", "\\operatorname{LogNormal}")
1720+
1721+
1722+
pymc_lognormal = PyMCLogNormalRV()
1723+
1724+
16931725
class LogNormal(PositiveContinuous):
16941726
r"""
16951727
Log-normal log-likelihood.
@@ -1758,7 +1790,8 @@ class LogNormal(PositiveContinuous):
17581790
x = pm.LogNormal('x', mu=2, tau=1/100)
17591791
"""
17601792

1761-
rv_op = lognormal
1793+
rv_op = pymc_lognormal
1794+
rv_type = LogNormalRV
17621795

17631796
@classmethod
17641797
def dist(cls, mu=0, sigma=None, tau=None, *args, **kwargs):
@@ -2049,6 +2082,13 @@ def pareto_default_transform(op, rv):
20492082
return bounded_cont_transform(op, rv, Pareto.bound_args_indices)
20502083

20512084

2085+
class PyMCCauchyRV(CauchyRV):
2086+
_print_name = ("Cauchy", "\\operatorname{Cauchy}")
2087+
2088+
2089+
pymc_cauchy = PyMCCauchyRV()
2090+
2091+
20522092
class Cauchy(Continuous):
20532093
r"""
20542094
Cauchy log-likelihood.
@@ -2095,7 +2135,8 @@ class Cauchy(Continuous):
20952135
beta : tensor_like of float
20962136
Scale parameter > 0.
20972137
"""
2098-
rv_op = cauchy
2138+
rv_op = pymc_cauchy
2139+
rv_type = CauchyRV
20992140

21002141
@classmethod
21012142
def dist(cls, alpha, beta, *args, **kwargs):
@@ -2133,6 +2174,13 @@ def logcdf(value, alpha, beta):
21332174
)
21342175

21352176

2177+
class PyMCHalfCauchyRV(HalfCauchyRV):
2178+
_print_name = ("HalfCauchy", "\\operatorname{HalfCauchy}")
2179+
2180+
2181+
pymc_halfcauchy = PyMCHalfCauchyRV()
2182+
2183+
21362184
class HalfCauchy(PositiveContinuous):
21372185
r"""
21382186
Half-Cauchy log-likelihood.
@@ -2172,7 +2220,8 @@ class HalfCauchy(PositiveContinuous):
21722220
beta : tensor_like of float
21732221
Scale parameter (beta > 0).
21742222
"""
2175-
rv_op = halfcauchy
2223+
rv_op = pymc_halfcauchy
2224+
rv_type = HalfCauchyRV
21762225

21772226
@classmethod
21782227
def dist(cls, beta, *args, **kwargs):
@@ -3942,7 +3991,7 @@ class PolyaGammaRV(RandomVariable):
39423991
ndim_supp = 0
39433992
ndims_params = [0, 0]
39443993
dtype = "floatX"
3945-
_print_name = ("PG", "\\operatorname{PG}")
3994+
_print_name = ("PolyaGamma", "\\operatorname{PolyaGamma}")
39463995

39473996
def __call__(self, h=1.0, z=0.0, size=None, **kwargs):
39483997
return super().__call__(h, z, size=size, **kwargs)

pymc/distributions/discrete.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717
import numpy as np
1818

1919
from aesara.tensor.random.basic import (
20+
GeometricRV,
21+
HyperGeometricRV,
22+
NegBinomialRV,
23+
PoissonRV,
2024
RandomVariable,
2125
ScipyRandomVariable,
2226
bernoulli,
2327
betabinom,
2428
binomial,
2529
categorical,
26-
geometric,
27-
hypergeometric,
28-
nbinom,
29-
poisson,
3030
)
3131
from scipy import stats
3232

@@ -560,6 +560,13 @@ def logcdf(value, q, beta):
560560
return check_parameters(res, 0 < q, q < 1, 0 < beta, msg="0 < q < 1, beta > 0")
561561

562562

563+
class PyMCPoissonRV(PoissonRV):
564+
_print_name = ("Poisson", "\\operatorname{Poisson}")
565+
566+
567+
pymc_poisson = PyMCPoissonRV()
568+
569+
563570
class Poisson(Discrete):
564571
R"""
565572
Poisson log-likelihood.
@@ -605,7 +612,8 @@ class Poisson(Discrete):
605612
The Poisson distribution can be derived as a limiting case of the
606613
binomial distribution.
607614
"""
608-
rv_op = poisson
615+
rv_op = pymc_poisson
616+
rv_type = PoissonRV
609617

610618
@classmethod
611619
def dist(cls, mu, *args, **kwargs):
@@ -674,6 +682,13 @@ def logcdf(value, mu):
674682
return check_parameters(res, 0 <= mu, msg="mu >= 0")
675683

676684

685+
class PyMCNegativeBinomialRV(NegBinomialRV):
686+
_print_name = ("NegBinom", "\\operatorname{NegBinom}")
687+
688+
689+
pymc_nbinom = PyMCNegativeBinomialRV()
690+
691+
677692
class NegativeBinomial(Discrete):
678693
R"""
679694
Negative binomial log-likelihood.
@@ -746,7 +761,8 @@ def NegBinom(a, m, x):
746761
n : tensor_like of float
747762
Alternative number of target success trials (n > 0)
748763
"""
749-
rv_op = nbinom
764+
rv_op = pymc_nbinom
765+
rv_type = NegBinomialRV
750766

751767
@classmethod
752768
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
@@ -847,6 +863,13 @@ def logcdf(value, n, p):
847863
)
848864

849865

866+
class PyMCGeometricRV(GeometricRV):
867+
_print_name = ("Geometric", "\\operatorname{Geometric}")
868+
869+
870+
pymc_geometric = PyMCGeometricRV()
871+
872+
850873
class Geometric(Discrete):
851874
R"""
852875
Geometric log-likelihood.
@@ -886,7 +909,8 @@ class Geometric(Discrete):
886909
Probability of success on an individual trial (0 < p <= 1).
887910
"""
888911

889-
rv_op = geometric
912+
rv_op = pymc_geometric
913+
rv_type = GeometricRV
890914

891915
@classmethod
892916
def dist(cls, p, *args, **kwargs):
@@ -956,6 +980,13 @@ def logcdf(value, p):
956980
)
957981

958982

983+
class PyMCHyperGeometricRV(HyperGeometricRV):
984+
_print_name = ("HyperGeometric", "\\operatorname{HyperGeometric}")
985+
986+
987+
pymc_hypergeometric = PyMCHyperGeometricRV()
988+
989+
959990
class HyperGeometric(Discrete):
960991
R"""
961992
Discrete hypergeometric distribution.
@@ -1004,7 +1035,8 @@ class HyperGeometric(Discrete):
10041035
Number of samples drawn from the population (0 <= n <= N)
10051036
"""
10061037

1007-
rv_op = hypergeometric
1038+
rv_op = pymc_hypergeometric
1039+
rv_type = HyperGeometricRV
10081040

10091041
@classmethod
10101042
def dist(cls, N, k, n, *args, **kwargs):

pymc/distributions/distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ def _random(*args, **kwargs):
102102
clsdict["random"] = _random
103103

104104
rv_op = clsdict.setdefault("rv_op", None)
105-
rv_type = None
105+
rv_type = clsdict.setdefault("rv_type", None)
106106

107-
if isinstance(rv_op, RandomVariable):
107+
if rv_type is None and isinstance(rv_op, RandomVariable):
108108
rv_type = type(rv_op)
109109
clsdict["rv_type"] = rv_type
110110

pymc/distributions/multivariate.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
from aesara.sparse.basic import sp_sum
3333
from aesara.tensor import gammaln, sigmoid
3434
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
35-
from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
35+
from aesara.tensor.random.basic import (
36+
DirichletRV,
37+
MvNormalRV,
38+
multinomial,
39+
multivariate_normal,
40+
)
3641
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
3742
from aesara.tensor.random.utils import broadcast_params
3843
from aesara.tensor.slinalg import Cholesky, SolveTriangular
@@ -190,6 +195,13 @@ def quaddist_tau(delta, chol_mat):
190195
return quaddist, logdet, ok
191196

192197

198+
class PyMCMvNormalRV(MvNormalRV):
199+
_print_name = ("MvNormal", "\\operatorname{MvNormal}")
200+
201+
202+
pymc_multivariate_normal = PyMCMvNormalRV()
203+
204+
193205
class MvNormal(Continuous):
194206
r"""
195207
Multivariate normal log-likelihood.
@@ -254,7 +266,8 @@ class MvNormal(Continuous):
254266
vals_raw = pm.Normal('vals_raw', mu=0, sigma=1, shape=(5, 3))
255267
vals = pm.Deterministic('vals', at.dot(chol, vals_raw.T).T)
256268
"""
257-
rv_op = multivariate_normal
269+
rv_op = pymc_multivariate_normal
270+
rv_type = MvNormalRV
258271

259272
@classmethod
260273
def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
@@ -436,6 +449,13 @@ def logp(value, nu, mu, scale):
436449
)
437450

438451

452+
class PyMCDirichletRV(DirichletRV):
453+
_print_name = ("Dirichlet", "\\operator{Dirichlet}")
454+
455+
456+
pymc_dirichlet = PyMCDirichletRV()
457+
458+
439459
class Dirichlet(SimplexContinuous):
440460
r"""
441461
Dirichlet log-likelihood.
@@ -460,7 +480,8 @@ class Dirichlet(SimplexContinuous):
460480
Concentration parameters (a > 0). The number of categories is given by the
461481
length of the last axis.
462482
"""
463-
rv_op = dirichlet
483+
rv_op = pymc_dirichlet
484+
rv_type = DirichletRV
464485

465486
@classmethod
466487
def dist(cls, a, **kwargs):

pymc/tests/distributions/test_logprob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def test_ignore_logprob_basic():
320320
new_x = ignore_logprob(x)
321321
assert new_x is not x
322322
assert isinstance(new_x.owner.op, Normal)
323-
assert type(new_x.owner.op).__name__ == "UnmeasurableNormalRV"
323+
assert type(new_x.owner.op).__name__ == "UnmeasurablePyMCNormalRV"
324324
# Confirm that it does not have measurable output
325325
assert get_measurable_outputs(new_x.owner.op, new_x.owner) is None
326326

pymc/tests/test_aesaraf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from aeppl.logprob import ParameterValueError
2626
from aesara.compile.builders import OpFromGraph
2727
from aesara.graph.basic import Variable, equal_computations
28-
from aesara.tensor.random.basic import normal, uniform
28+
from aesara.tensor.random.basic import NormalRV, normal, uniform
2929
from aesara.tensor.random.op import RandomVariable
3030
from aesara.tensor.random.var import RandomStateSharedVariable
3131
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -405,7 +405,7 @@ def test_rvs_to_value_vars_unvalued_rv():
405405
res_y = res.owner.inputs[1]
406406
# Graph should have be cloned, and therefore y and res_y should have different ids
407407
assert res_y is not y
408-
assert res_y.owner.op == at.random.normal
408+
assert isinstance(res_y.owner.op, NormalRV)
409409
assert res_y.owner.inputs[3] is x_value
410410

411411

0 commit comments

Comments
 (0)