|
2 | 2 |
|
3 | 3 | import numpy as np
|
4 | 4 | import pytest
|
| 5 | +import scipy.special |
5 | 6 |
|
| 7 | +import pytensor |
6 | 8 | import pytensor.tensor as at
|
7 | 9 | import pytensor.tensor.inplace as ati
|
8 | 10 | import pytensor.tensor.math as aem
|
@@ -532,3 +534,24 @@ def test_MaxAndArgmax(x, axes, exc):
|
532 | 534 | if not isinstance(i, (SharedVariable, Constant))
|
533 | 535 | ],
|
534 | 536 | )
|
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
    """Benchmark a JAX-compiled log-sum-exp graph and validate it against SciPy.

    Builds the numerically stable log-sum-exp (subtracting the per-axis max
    before exponentiating), compiles it with the JAX backend, and compares
    the result to ``scipy.special.logsumexp``.
    """
    x = at.matrix("X")

    x_max = at.max(x, axis=axis, keepdims=True)
    # Guard against slices that are entirely -inf, whose max would be
    # non-finite and poison the subtraction below.
    x_max = at.switch(at.isinf(x_max), 0, x_max)
    lse = at.log(at.sum(at.exp(x - x_max), axis=axis, keepdims=True)) + x_max

    lse_fn = pytensor.function([x], lse, mode="JAX")

    rng = np.random.default_rng(23920)
    x_val = rng.normal(size=size)

    # Warm-up call so JIT compilation happens outside the timed region.
    _ = lse_fn(x_val)
    result = benchmark(lse_fn, x_val)

    expected = scipy.special.logsumexp(x_val, axis=axis, keepdims=True)
    np.testing.assert_array_almost_equal(result, expected)
0 commit comments