Skip to content

Commit 4832992

Browse files
tvwengerricardoV94
authored andcommitted
Implement Chi distribution helper
1 parent ebba64e commit 4832992

File tree

4 files changed

+94
-4
lines changed

4 files changed

+94
-4
lines changed

docs/api_reference.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ Distributions
2828
.. autosummary::
2929
:toctree: generated/
3030

31-
GenExtreme
32-
GeneralizedPoisson
31+
Chi
3332
DiscreteMarkovChain
33+
GeneralizedPoisson
34+
GenExtreme
3435
R2D2M2CP
3536
histogram_approximation
3637

pymc_experimental/distributions/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Experimental probability distributions for stochastic nodes in PyMC.
1818
"""
1919

20-
from pymc_experimental.distributions.continuous import GenExtreme
20+
from pymc_experimental.distributions.continuous import Chi, GenExtreme
2121
from pymc_experimental.distributions.discrete import GeneralizedPoisson
2222
from pymc_experimental.distributions.histogram_utils import histogram_approximation
2323
from pymc_experimental.distributions.multivariate import R2D2M2CP
@@ -29,4 +29,5 @@
2929
"GenExtreme",
3030
"R2D2M2CP",
3131
"histogram_approximation",
32+
"Chi",
3233
]

pymc_experimental/distributions/continuous.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
import numpy as np
2525
import pytensor.tensor as pt
26+
from pymc import ChiSquared, CustomDist
27+
from pymc.distributions import transforms
2628
from pymc.distributions.dist_math import check_parameters
2729
from pymc.distributions.distribution import Continuous
2830
from pymc.distributions.shape_utils import rv_size_is_none
@@ -216,3 +218,65 @@ def moment(rv, size, mu, sigma, xi):
216218
if not rv_size_is_none(size):
217219
mode = pt.full(size, mode)
218220
return mode
221+
222+
223+
class Chi:
224+
r"""
225+
:math:`\chi` log-likelihood.
226+
227+
The pdf of this distribution is
228+
229+
.. math::
230+
231+
f(x \mid \nu) = \frac{x^{\nu - 1}e^{-x^2/2}}{2^{\nu/2 - 1}\Gamma(\nu/2)}
232+
233+
.. plot::
234+
:context: close-figs
235+
236+
import matplotlib.pyplot as plt
237+
import numpy as np
238+
import scipy.stats as st
239+
import arviz as az
240+
plt.style.use('arviz-darkgrid')
241+
x = np.linspace(0, 10, 200)
242+
for df in [1, 2, 3, 6, 9]:
243+
pdf = st.chi.pdf(x, df)
244+
plt.plot(x, pdf, label=r'$\nu$ = {}'.format(df))
245+
plt.xlabel('x', fontsize=12)
246+
plt.ylabel('f(x)', fontsize=12)
247+
plt.legend(loc=1)
248+
plt.show()
249+
250+
======== =========================================================================
251+
Support :math:`x \in [0, \infty)`
252+
Mean :math:`\sqrt{2}\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}`
253+
Variance :math:`\nu - 2\left(\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}\right)^2`
254+
======== =========================================================================
255+
256+
Parameters
257+
----------
258+
nu : tensor_like of float
259+
Degrees of freedom (nu > 0).
260+
261+
Examples
262+
--------
263+
.. code-block:: python
264+
import pymc as pm
265+
from pymc_experimental.distributions import Chi
266+
267+
with pm.Model():
268+
x = Chi('x', nu=1)
269+
"""
270+
271+
@staticmethod
272+
def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:
273+
return pt.math.sqrt(ChiSquared.dist(nu=nu, size=size))
274+
275+
def __new__(cls, name, nu, **kwargs):
276+
if "observed" not in kwargs:
277+
kwargs.setdefault("transform", transforms.log)
278+
return CustomDist(name, nu, dist=cls.chi_dist, class_name="Chi", **kwargs)
279+
280+
@classmethod
281+
def dist(cls, nu, **kwargs):
282+
return CustomDist.dist(nu, dist=cls.chi_dist, class_name="Chi", **kwargs)

pymc_experimental/tests/distributions/test_continuous.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
BaseTestDistributionRandom,
2727
Domain,
2828
R,
29+
Rplus,
2930
Rplusbig,
3031
assert_moment_is_expected,
3132
check_logcdf,
@@ -35,7 +36,7 @@
3536
)
3637

3738
# the distributions to be tested
38-
from pymc_experimental.distributions import GenExtreme
39+
from pymc_experimental.distributions import Chi, GenExtreme
3940

4041

4142
class TestGenExtremeClass:
@@ -149,3 +150,26 @@ class TestGenExtreme(BaseTestDistributionRandom):
149150
"check_pymc_draws_match_reference",
150151
"check_rv_size",
151152
]
153+
154+
155+
class TestChiClass:
156+
"""
157+
Wrapper class so that tests of experimental additions can be dropped into
158+
PyMC directly on adoption.
159+
"""
160+
161+
def test_logp(self):
162+
check_logp(
163+
Chi,
164+
Rplus,
165+
{"nu": Rplus},
166+
lambda value, nu: sp.chi.logpdf(value, df=nu),
167+
)
168+
169+
def test_logcdf(self):
170+
check_logcdf(
171+
Chi,
172+
Rplus,
173+
{"nu": Rplus},
174+
lambda value, nu: sp.chi.logcdf(value, df=nu),
175+
)

0 commit comments

Comments
 (0)