Skip to content

Commit 0e9b9a4

Browse files
authored
Improve logsumexp to work with infinite values (#4360)
* Make logsumexp work with inifinite values, matching scipy behavior * Run pre-commit * Add note to release_notes
1 parent 34447a7 commit 0e9b9a4

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ This is the first release to support Python3.9 and to drop Python3.6.
1313
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).
1414
- In `sample_posterior_predictive` the `vars` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc3/pull/4343)).
1515
- The notebook gallery has been moved to https://github.com/pymc-devs/pymc-examples (see [#4348](https://github.com/pymc-devs/pymc3/pull/4348)).
16-
16+
- `math.logsumexp` now matches `scipy.special.logsumexp` when arrays contain infinite values (see [#4360](https://github.com/pymc-devs/pymc3/pull/4360)).
1717

1818
## PyMC3 3.10.0 (7 December 2020)
1919

pymc3/math.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def tround(*args, **kwargs):
175175
def logsumexp(x, axis=None, keepdims=True):
176176
# Adapted from https://github.com/Theano/Theano/issues/1563
177177
x_max = tt.max(x, axis=axis, keepdims=True)
178+
x_max = tt.switch(tt.isinf(x_max), 0, x_max)
178179
res = tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
179180
return res if keepdims else res.squeeze()
180181

pymc3/tests/test_math.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import theano
1919
import theano.tensor as tt
2020

21+
from scipy.special import logsumexp as scipy_logsumexp
22+
2123
from pymc3.math import (
2224
LogDet,
2325
cartesian,
@@ -30,6 +32,7 @@
3032
log1mexp_numpy,
3133
log1pexp,
3234
logdet,
35+
logsumexp,
3336
probit,
3437
)
3538
from pymc3.tests.helpers import SeededTest, verify_grad
@@ -207,3 +210,27 @@ def test_expand_packed_triangular():
207210
assert np.all(expand_upper.eval({packed: upper_packed}) == upper)
208211
assert np.all(expand_diag_lower.eval({packed: lower_packed}) == floatX(np.diag(vals)))
209212
assert np.all(expand_diag_upper.eval({packed: upper_packed}) == floatX(np.diag(vals)))
213+
214+
215+
@pytest.mark.parametrize(
216+
"values, axis, keepdims",
217+
[
218+
(np.array([-4, -2]), None, True),
219+
(np.array([-np.inf, -2]), None, True),
220+
(np.array([-2, np.inf]), None, True),
221+
(np.array([-np.inf, -np.inf]), None, True),
222+
(np.array([np.inf, np.inf]), None, True),
223+
(np.array([-np.inf, np.inf]), None, True),
224+
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), None, True),
225+
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 0, True),
226+
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 1, True),
227+
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 0, False),
228+
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 1, False),
229+
(np.array([[-2, np.inf], [-np.inf, -np.inf]]), 0, True),
230+
],
231+
)
232+
def test_logsumexp(values, axis, keepdims):
233+
npt.assert_almost_equal(
234+
logsumexp(values, axis=axis, keepdims=keepdims).eval(),
235+
scipy_logsumexp(values, axis=axis, keepdims=keepdims),
236+
)

0 commit comments

Comments
 (0)