-
-
Notifications
You must be signed in to change notification settings - Fork 62
Genextreme #84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Genextreme #84
Changes from 8 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
f1a67b9
Initial commit
ccaprani 959677f
Updated structure; code operational
ccaprani 929bdfa
Change scipy.ststats.distribution import
ccaprani 8cfc35d
Added moment test
ccaprani 85c3782
Update per #84 review; tests, and dist code
ccaprani c43423f
Updated tests and class per comments
ccaprani 13f9ade
Tweaks post-review
ccaprani b181815
Test tweaks and all pass
ccaprani 14c8252
Resolve unnecessary check
ccaprani 8471fca
Reverse constraints msg
ccaprani e90be67
Clearing precommit requirements
ccaprani 306ce90
Skip tests on float32 due to underflow of pymc vs scipy
ccaprani 2301015
Improve test grid for parameters
ccaprani 0022a2f
Revert mu test domain to R
ccaprani 544cf42
Separate logp and logcdf tests
ricardoV94 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,26 @@ | ||
from pymc_experimental.distributions import histogram_utils | ||
from pymc_experimental.distributions.histogram_utils import histogram_approximation | ||
# Copyright 2022 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# coding: utf-8 | ||
""" | ||
Experimental probability distributions for stochastic nodes in PyMC. | ||
""" | ||
|
||
from pymc_experimental.distributions.continuous import ( | ||
GenExtreme, | ||
) | ||
|
||
__all__ = [ | ||
"GenExtreme", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
# Copyright 2022 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# coding: utf-8 | ||
""" | ||
Experimental probability distributions for stochastic nodes in PyMC. | ||
|
||
The imports from pymc are not fully replicated here: add imports as necessary. | ||
""" | ||
|
||
from typing import List, Tuple, Union | ||
import aesara | ||
import aesara.tensor as at | ||
import numpy as np | ||
from scipy import stats | ||
|
||
from aesara.tensor.random.op import RandomVariable | ||
from pymc.distributions.distribution import Continuous | ||
from aesara.tensor.var import TensorVariable | ||
from pymc.aesaraf import floatX | ||
from pymc.distributions.dist_math import check_parameters | ||
from pymc.distributions.shape_utils import rv_size_is_none | ||
|
||
|
||
class GenExtremeRV(RandomVariable): | ||
name: str = "Generalized Extreme Value" | ||
ndim_supp: int = 0 | ||
ndims_params: List[int] = [0, 0, 0] | ||
dtype: str = "floatX" | ||
_print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}") | ||
|
||
def __call__( | ||
self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs | ||
) -> TensorVariable: | ||
return super().__call__(mu, sigma, xi, size=size, **kwargs) | ||
|
||
@classmethod | ||
def rng_fn( | ||
cls, | ||
rng: Union[np.random.RandomState, np.random.Generator], | ||
mu: np.ndarray, | ||
sigma: np.ndarray, | ||
xi: np.ndarray, | ||
size: Tuple[int, ...], | ||
) -> np.ndarray: | ||
# Notice negative here, since remainder of GenExtreme is based on Coles parametrization | ||
return stats.genextreme.rvs( | ||
c=-xi, loc=mu, scale=sigma, random_state=rng, size=size | ||
) | ||
|
||
|
||
gev = GenExtremeRV() | ||
|
||
|
||
class GenExtreme(Continuous): | ||
r""" | ||
Univariate Generalized Extreme Value log-likelihood | ||
|
||
The cdf of this distribution is | ||
|
||
.. math:: | ||
|
||
G(x \mid \mu, \sigma, \xi) = \exp\left[ -\left(1 + \xi z\right)^{-\frac{1}{\xi}} \right] | ||
|
||
where | ||
|
||
.. math:: | ||
|
||
z = \frac{x - \mu}{\sigma} | ||
|
||
and is defined on the set: | ||
|
||
.. math:: | ||
|
||
\left\{x: 1 + \xi\left(\frac{x-\mu}{\sigma}\right) > 0 \right\}. | ||
|
||
Note that this parametrization is per Coles (2001), and differs from that of | ||
Scipy in the sign of the shape parameter, :math:`\xi`. | ||
|
||
.. plot:: | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import scipy.stats as st | ||
import arviz as az | ||
plt.style.use('arviz-darkgrid') | ||
x = np.linspace(-10, 20, 200) | ||
mus = [0., 4., -1.] | ||
sigmas = [2., 2., 4.] | ||
xis = [-0.3, 0.0, 0.3] | ||
for mu, sigma, xi in zip(mus, sigmas, xis): | ||
pdf = st.genextreme.pdf(x, c=-xi, loc=mu, scale=sigma) | ||
plt.plot(x, pdf, label=rf'$\mu$ = {mu}, $\sigma$ = {sigma}, $\xi$={xi}') | ||
plt.xlabel('x', fontsize=12) | ||
plt.ylabel('f(x)', fontsize=12) | ||
plt.legend(loc=1) | ||
plt.show() | ||
|
||
|
||
======== ========================================================================= | ||
Support * :math:`x \in [\mu - \sigma/\xi, +\infty]`, when :math:`\xi > 0` | ||
* :math:`x \in \mathbb{R}` when :math:`\xi = 0` | ||
* :math:`x \in [-\infty, \mu - \sigma/\xi]`, when :math:`\xi < 0` | ||
Mean * :math:`\mu + \sigma(g_1 - 1)/\xi`, when :math:`\xi \neq 0, \xi < 1` | ||
* :math:`\mu + \sigma \gamma`, when :math:`\xi = 0` | ||
* :math:`\infty`, when :math:`\xi \geq 1` | ||
where :math:`\gamma` is the Euler-Mascheroni constant, and | ||
:math:`g_k = \Gamma (1-k\xi)` | ||
Variance * :math:`\sigma^2 (g_2 - g_1^2)/\xi^2`, when :math:`\xi \neq 0, \xi < 0.5` | ||
* :math:`\frac{\pi^2}{6} \sigma^2`, when :math:`\xi = 0` | ||
* :math:`\infty`, when :math:`\xi \geq 0.5` | ||
======== ========================================================================= | ||
|
||
Parameters | ||
---------- | ||
mu: float | ||
Location parameter. | ||
sigma: float | ||
Scale parameter (sigma > 0). | ||
xi: float | ||
Shape parameter | ||
scipy: bool | ||
Whether or not to use the Scipy interpretation of the shape parameter | ||
(defaults to `False`). | ||
|
||
References | ||
---------- | ||
.. [Coles2001] Coles, S.G. (2001). | ||
An Introduction to the Statistical Modeling of Extreme Values | ||
Springer-Verlag, London | ||
|
||
""" | ||
|
||
rv_op = gev | ||
|
||
@classmethod | ||
def dist(cls, mu=0, sigma=1, xi=0, scipy=False, **kwargs): | ||
# If SciPy, use its parametrization, otherwise convert to standard | ||
if scipy: | ||
xi = -xi | ||
mu = at.as_tensor_variable(floatX(mu)) | ||
sigma = at.as_tensor_variable(floatX(sigma)) | ||
xi = at.as_tensor_variable(floatX(xi)) | ||
|
||
return super().dist([mu, sigma, xi], **kwargs) | ||
|
||
def logp(value, mu, sigma, xi): | ||
""" | ||
Calculate log-probability of Generalized Extreme Value distribution | ||
at specified value. | ||
|
||
Parameters | ||
---------- | ||
value: numeric | ||
Value(s) for which log-probability is calculated. If the log probabilities for multiple | ||
values are desired the values must be provided in a numpy array or Aesara tensor | ||
|
||
Returns | ||
------- | ||
TensorVariable | ||
""" | ||
scaled = (value - mu) / sigma | ||
|
||
logp_expression = at.switch( | ||
at.isclose(xi, 0), | ||
-at.log(sigma) - scaled - at.exp(-scaled), | ||
-at.log(sigma) | ||
- ((xi + 1) / xi) * at.log1p(xi * scaled) | ||
- at.pow(1 + xi * scaled, -1 / xi), | ||
) | ||
|
||
logp = at.switch( | ||
at.gt(1 + xi * scaled, 0.0), | ||
logp_expression, | ||
-np.inf) | ||
|
||
return check_parameters( | ||
logp, | ||
sigma > 0, | ||
1 + xi * scaled > 0, | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
at.and_(xi > -1, xi < 1), | ||
msg="sigma <= 0 or 1+xi*(x-mu)/sigma <= 0") | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def logcdf(value, mu, sigma, xi): | ||
""" | ||
Compute the log of the cumulative distribution function for Generalized Extreme Value | ||
distribution at the specified value. | ||
|
||
Parameters | ||
---------- | ||
value: numeric or np.ndarray or `TensorVariable` | ||
Value(s) for which log CDF is calculated. If the log CDF for | ||
multiple values are desired the values must be provided in a numpy | ||
array or `TensorVariable`. | ||
|
||
Returns | ||
------- | ||
TensorVariable | ||
""" | ||
scaled = (value - mu) / sigma | ||
logc_expression = at.switch( | ||
at.isclose(xi, 0), -at.exp(-scaled), -at.pow(1 + xi * scaled, -1 / xi) | ||
) | ||
|
||
logc = at.switch( | ||
1 + xi * (value - mu) / sigma > 0, | ||
logc_expression, | ||
-np.inf) | ||
|
||
return check_parameters(logc, | ||
sigma > 0, | ||
1 + xi * scaled > 0, | ||
at.and_(xi > -1, xi < 1), | ||
msg="sigma <= 0 or 1+xi*(x-mu)/sigma <= 0") | ||
|
||
def moment(rv, size, mu, sigma, xi): | ||
r""" | ||
Using the mode, as the mean can be infinite when :math:`\xi > 1` | ||
""" | ||
mode = at.switch( | ||
at.isclose(xi, 0), mu, mu + sigma * (at.pow(1 + xi, -xi) - 1) / xi | ||
) | ||
if not rv_size_is_none(size): | ||
mode = at.full(size, mode) | ||
return mode |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from pymc_experimental.distributions import histogram_utils | ||
from pymc_experimental.distributions.histogram_utils import histogram_approximation |
119 changes: 119 additions & 0 deletions
119
pymc_experimental/tests/distributions/test_continuous.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
|
||
# Copyright 2020 The PyMC Developers | ||
ccaprani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# general imports | ||
import numpy as np | ||
import pytest | ||
import scipy.stats as st | ||
import scipy.stats.distributions as sp | ||
|
||
import pymc as pm | ||
|
||
# test support imports from pymc | ||
from pymc.tests.distributions.util import ( | ||
BaseTestDistributionRandom, | ||
R, | ||
Rplus, | ||
Domain, | ||
check_logp, | ||
check_logcdf, | ||
assert_moment_is_expected, | ||
seeded_scipy_distribution_builder, | ||
) | ||
|
||
# the distributions to be tested | ||
from pymc_experimental.distributions import ( | ||
GenExtreme, | ||
) | ||
|
||
class TestGenExtremeClass: | ||
""" | ||
Wrapper class so that tests of experimental additions can be dropped into | ||
PyMC directly on adoption. | ||
|
||
pm.logp(GenExtreme.dist(mu=0.,sigma=1.,xi=0.5),value=-0.01) | ||
""" | ||
|
||
def test_genextreme(self): | ||
check_logp( | ||
GenExtreme, | ||
R, | ||
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1])}, | ||
lambda value, mu, sigma, xi: sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma) | ||
if 1 + xi*(value-mu)/sigma > 0 else -np.inf | ||
) | ||
check_logcdf( | ||
GenExtreme, | ||
R, | ||
{"mu": R, "sigma": Rplus, "xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1])}, | ||
lambda value, mu, sigma, xi: sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma) | ||
if 1 + xi*(value-mu)/sigma > 0 else -np.inf | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
"mu, sigma, xi, size, expected", | ||
[ | ||
(0, 1, 0, None, 0), | ||
(1, np.arange(1, 4), 0.1, None, 1 + np.arange(1, 4) * (1.1 ** -0.1 - 1) / 0.1), | ||
(np.arange(5), 1, 0.1, None, np.arange(5) + (1.1 ** -0.1 - 1) / 0.1), | ||
( | ||
0, | ||
1, | ||
np.linspace(-0.2, 0.2, 6), | ||
None, | ||
((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1) | ||
/ np.linspace(-0.2, 0.2, 6), | ||
), | ||
(1, 2, 0.1, 5, np.full(5, 1 + 2 * (1.1 ** -0.1 - 1) / 0.1)), | ||
( | ||
np.arange(6), | ||
np.arange(1, 7), | ||
np.linspace(-0.2, 0.2, 6), | ||
(3, 6), | ||
np.full( | ||
(3, 6), | ||
np.arange(6) | ||
+ np.arange(1, 7) | ||
* ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1) | ||
/ np.linspace(-0.2, 0.2, 6), | ||
), | ||
), | ||
], | ||
) | ||
def test_genextreme_moment(self, mu, sigma, xi, size, expected): | ||
with pm.Model() as model: | ||
GenExtreme("x", mu=mu, sigma=sigma, xi=xi, size=size) | ||
assert_moment_is_expected(model, expected) | ||
|
||
def test_gen_extreme_scipy_kwarg(self): | ||
dist = GenExtreme.dist(xi=1, scipy=False) | ||
assert dist.owner.inputs[-1].eval() == 1 | ||
|
||
dist = GenExtreme.dist(xi=1, scipy=True) | ||
assert dist.owner.inputs[-1].eval() == -1 | ||
|
||
|
||
class TestGenExtreme(BaseTestDistributionRandom): | ||
pymc_dist = GenExtreme | ||
pymc_dist_params = {"mu": 0, "sigma": 1, "xi": -0.1} | ||
expected_rv_op_params = {"mu": 0, "sigma": 1, "xi": -0.1} | ||
# Notice, using different parametrization of xi sign to scipy | ||
reference_dist_params = {"loc": 0, "scale": 1, "c": 0.1} | ||
reference_dist = seeded_scipy_distribution_builder("genextreme") | ||
tests_to_run = [ | ||
"check_pymc_params_match_rv_op", | ||
"check_pymc_draws_match_reference", | ||
"check_rv_size", | ||
] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.