Skip to content

Commit 0ff1675

Browse files
committed
attempt to implement Chi distribution helper
1 parent 4500708 commit 0ff1675

File tree

4 files changed

+118
-2
lines changed

4 files changed

+118
-2
lines changed

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Distributions
3333
DiscreteMarkovChain
3434
R2D2M2CP
3535
histogram_approximation
36+
Chi
3637

3738

3839
Model Transformations

pymc_experimental/distributions/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Experimental probability distributions for stochastic nodes in PyMC.
1818
"""
1919

20-
from pymc_experimental.distributions.continuous import GenExtreme
20+
from pymc_experimental.distributions.continuous import Chi, GenExtreme
2121
from pymc_experimental.distributions.discrete import GeneralizedPoisson
2222
from pymc_experimental.distributions.histogram_utils import histogram_approximation
2323
from pymc_experimental.distributions.multivariate import R2D2M2CP
@@ -29,4 +29,5 @@
2929
"GenExtreme",
3030
"R2D2M2CP",
3131
"histogram_approximation",
32+
"Chi",
3233
]

pymc_experimental/distributions/continuous.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
import numpy as np
2525
import pytensor.tensor as pt
26+
from pymc import ChiSquared, CustomDist
27+
from pymc.distributions import transforms
2628
from pymc.distributions.dist_math import check_parameters
2729
from pymc.distributions.distribution import Continuous
2830
from pymc.distributions.shape_utils import rv_size_is_none
@@ -216,3 +218,63 @@ def moment(rv, size, mu, sigma, xi):
216218
if not rv_size_is_none(size):
217219
mode = pt.full(size, mode)
218220
return mode
221+
222+
223+
def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:
224+
return pt.math.sqrt(ChiSquared.dist(nu=nu, size=size))
225+
226+
227+
class Chi:
228+
r"""
229+
:math:`\chi` log-likelihood.
230+
231+
The pdf of this distribution is
232+
233+
.. math::
234+
235+
f(x \mid \nu) = \frac{x^{\nu - 1}e^{-x^2/2}}{2^{\nu/2 - 1}\Gamma(\nu/2)}
236+
237+
.. plot::
238+
:context: close-figs
239+
240+
import matplotlib.pyplot as plt
241+
import numpy as np
242+
import scipy.stats as st
243+
import arviz as az
244+
plt.style.use('arviz-darkgrid')
245+
x = np.linspace(0, 10, 200)
246+
for df in [1, 2, 3, 6, 9]:
247+
pdf = st.chi.pdf(x, df)
248+
plt.plot(x, pdf, label=r'$\nu$ = {}'.format(df))
249+
plt.xlabel('x', fontsize=12)
250+
plt.ylabel('f(x)', fontsize=12)
251+
plt.legend(loc=1)
252+
plt.show()
253+
254+
======== =========================================================================
255+
Support :math:`x \in [0, \infty)`
256+
Mean :math:`\sqrt{2}\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}`
257+
Variance :math:`\nu - 2\left(\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}\right)^2`
258+
======== =========================================================================
259+
260+
Parameters
261+
----------
262+
nu : tensor_like of float
263+
Degrees of freedom (nu > 0).
264+
265+
Examples
266+
--------
267+
.. code-block:: python
268+
269+
with pm.Model():
270+
x = pm.Chi('x', nu=1)
271+
"""
272+
273+
def __new__(cls, name, nu, **kwargs):
274+
if "observed" not in kwargs:
275+
kwargs.setdefault("transform", transforms.log)
276+
return CustomDist(name, (nu,), dist=chi_dist, class_name="Chi", **kwargs)
277+
278+
@classmethod
279+
def dist(cls, nu, **kwargs):
280+
return CustomDist.dist(nu, dist=chi_dist, class_name="Chi", **kwargs)

pymc_experimental/tests/distributions/test_continuous.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
BaseTestDistributionRandom,
2727
Domain,
2828
R,
29+
Rplus,
2930
Rplusbig,
3031
assert_moment_is_expected,
3132
check_logcdf,
@@ -35,7 +36,7 @@
3536
)
3637

3738
# the distributions to be tested
38-
from pymc_experimental.distributions import GenExtreme
39+
from pymc_experimental.distributions import Chi, GenExtreme
3940

4041

4142
class TestGenExtremeClass:
@@ -149,3 +150,54 @@ class TestGenExtreme(BaseTestDistributionRandom):
149150
"check_pymc_draws_match_reference",
150151
"check_rv_size",
151152
]
153+
154+
155+
class TestChiClass:
156+
"""
157+
Wrapper class so that tests of experimental additions can be dropped into
158+
PyMC directly on adoption.
159+
"""
160+
161+
def test_logp(self):
162+
check_logp(
163+
Chi,
164+
Rplus,
165+
{"nu": Rplus},
166+
lambda value, nu: sp.chi.logpdf(value, df=nu),
167+
)
168+
169+
def test_logcdf(self):
170+
check_logcdf(
171+
Chi,
172+
Rplus,
173+
{"nu": Rplus},
174+
lambda value, nu: sp.chi.logcdf(value, df=nu),
175+
)
176+
177+
"""
178+
@pytest.mark.parametrize(
179+
"nu, size, expected",
180+
[
181+
(1, None, 1),
182+
(1, 5, np.full(5, 1)),
183+
(np.arange(1, 6), None, np.arange(1, 6)),
184+
],
185+
)
186+
def test_chi_moment(self, nu, size, expected):
187+
with pm.Model() as model:
188+
Chi("x", nu=nu, size=size)
189+
assert_moment_is_expected(model, expected)
190+
"""
191+
192+
193+
class TestChi(BaseTestDistributionRandom):
194+
pymc_dist = Chi
195+
pymc_dist_params = {"nu": 3.0}
196+
expected_rv_op_params = {"nu": 3.0}
197+
reference_dist_params = {"df": 3.0}
198+
reference_dist = seeded_scipy_distribution_builder("chi")
199+
tests_to_run = [
200+
"check_pymc_params_match_rv_op",
201+
"check_pymc_draws_match_reference",
202+
"check_rv_size",
203+
]

0 commit comments

Comments
 (0)