Skip to content

Commit 6d8f569

Browse files
authored
Implement Maxwell Distribution (#261)
1 parent b521e64 commit 6d8f569

File tree

4 files changed

+98
-2
lines changed

4 files changed

+98
-2
lines changed

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Distributions
2929
:toctree: generated/
3030

3131
Chi
32+
Maxwell
3233
DiscreteMarkovChain
3334
GeneralizedPoisson
3435
BetaNegativeBinomial

pymc_experimental/distributions/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Experimental probability distributions for stochastic nodes in PyMC.
1818
"""
1919

20-
from pymc_experimental.distributions.continuous import Chi, GenExtreme
20+
from pymc_experimental.distributions.continuous import Chi, GenExtreme, Maxwell
2121
from pymc_experimental.distributions.discrete import (
2222
BetaNegativeBinomial,
2323
GeneralizedPoisson,
@@ -36,4 +36,5 @@
3636
"Skellam",
3737
"histogram_approximation",
3838
"Chi",
39+
"Maxwell",
3940
]

pymc_experimental/distributions/continuous.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pymc.distributions.dist_math import check_parameters
2929
from pymc.distributions.distribution import Continuous
3030
from pymc.distributions.shape_utils import rv_size_is_none
31+
from pymc.logprob.utils import CheckParameterValue
3132
from pymc.pytensorf import floatX
3233
from pytensor.tensor.random.op import RandomVariable
3334
from pytensor.tensor.variable import TensorVariable
@@ -280,3 +281,73 @@ def __new__(cls, name, nu, **kwargs):
280281
@classmethod
281282
def dist(cls, nu, **kwargs):
282283
return CustomDist.dist(nu, dist=cls.chi_dist, class_name="Chi", **kwargs)
284+
285+
286+
class Maxwell:
287+
R"""
288+
The Maxwell-Boltzmann distribution
289+
290+
The pdf of this distribution is
291+
292+
.. math::
293+
294+
f(x \mid a) = {\displaystyle {\sqrt {\frac {2}{\pi }}}\,{\frac {x^{2}}{a^{3}}}\,\exp \left({\frac {-x^{2}}{2a^{2}}}\right)}
295+
296+
Read more about it on `Wikipedia <https://en.wikipedia.org/wiki/Maxwell%E2%80%93Boltzmann_distribution>`_
297+
298+
.. plot::
299+
:context: close-figs
300+
301+
import matplotlib.pyplot as plt
302+
import numpy as np
303+
import scipy.stats as st
304+
import arviz as az
305+
plt.style.use('arviz-darkgrid')
306+
x = np.linspace(0, 20, 200)
307+
for a in [1, 2, 5]:
308+
pdf = st.maxwell.pdf(x, scale=a)
309+
plt.plot(x, pdf, label=r'$a$ = {}'.format(a))
310+
plt.xlabel('x', fontsize=12)
311+
plt.ylabel('f(x)', fontsize=12)
312+
plt.legend(loc=1)
313+
plt.show()
314+
315+
======== =========================================================================
316+
Support :math:`x \in (0, \infty)`
317+
Mean :math:`2a \sqrt{\frac{2}{\pi}}`
318+
Variance :math:`\frac{a^2(3 \pi - 8)}{\pi}`
319+
======== =========================================================================
320+
321+
Parameters
322+
----------
323+
a : tensor_like of float
324+
Scale parameter (a > 0).
325+
326+
"""
327+
328+
@staticmethod
329+
def maxwell_dist(a: TensorVariable, size: TensorVariable) -> TensorVariable:
330+
if rv_size_is_none(size):
331+
size = a.shape
332+
333+
a = CheckParameterValue("a > 0")(a, pt.all(pt.gt(a, 0)))
334+
335+
return Chi.dist(nu=3, size=size) * a
336+
337+
def __new__(cls, name, a, **kwargs):
338+
return CustomDist(
339+
name,
340+
a,
341+
dist=cls.maxwell_dist,
342+
class_name="Maxwell",
343+
**kwargs,
344+
)
345+
346+
@classmethod
347+
def dist(cls, a, **kwargs):
348+
return CustomDist.dist(
349+
a,
350+
dist=cls.maxwell_dist,
351+
class_name="Maxwell",
352+
**kwargs,
353+
)

pymc_experimental/tests/distributions/test_continuous.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434

3535
# the distributions to be tested
36-
from pymc_experimental.distributions import Chi, GenExtreme
36+
from pymc_experimental.distributions import Chi, GenExtreme, Maxwell
3737

3838

3939
class TestGenExtremeClass:
@@ -159,3 +159,26 @@ def test_logcdf(self):
159159
{"nu": Rplus},
160160
lambda value, nu: sp.chi.logcdf(value, df=nu),
161161
)
162+
163+
164+
class TestMaxwell:
165+
"""
166+
Wrapper class so that tests of experimental additions can be dropped into
167+
PyMC directly on adoption.
168+
"""
169+
170+
def test_logp(self):
171+
check_logp(
172+
Maxwell,
173+
Rplus,
174+
{"a": Rplus},
175+
lambda value, a: sp.maxwell.logpdf(value, scale=a),
176+
)
177+
178+
def test_logcdf(self):
179+
check_logcdf(
180+
Maxwell,
181+
Rplus,
182+
{"a": Rplus},
183+
lambda value, a: sp.maxwell.logcdf(value, scale=a),
184+
)

0 commit comments

Comments
 (0)