Skip to content

Commit f3a4f2b

Browse files
Seed logsumexp benchmark tests
Also adds missing numba benchmark test Co-authored-by: Brandon T. Willard <[email protected]>
1 parent b8831aa commit f3a4f2b

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

tests/link/jax/test_elemwise.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def test_logsumexp_benchmark(size, axis, benchmark):
111111
X_max = at.switch(at.isinf(X_max), 0, X_max)
112112
X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max
113113

114-
X_val = np.random.normal(size=size)
114+
rng = np.random.default_rng(23920)
115+
X_val = rng.normal(size=size)
115116

116117
X_lse_fn = pytensor.function([X], X_lse, mode="JAX")
117118

tests/link/numba/test_elemwise.py

+23
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import numpy as np
44
import pytest
5+
import scipy.special
56

7+
import pytensor
68
import pytensor.tensor as at
79
import pytensor.tensor.inplace as ati
810
import pytensor.tensor.math as aem
@@ -532,3 +534,24 @@ def test_MaxAndArgmax(x, axes, exc):
532534
if not isinstance(i, (SharedVariable, Constant))
533535
],
534536
)
537+
538+
539+
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
540+
@pytest.mark.parametrize("axis", [0, 1])
541+
def test_logsumexp_benchmark(size, axis, benchmark):
542+
543+
X = at.matrix("X")
544+
X_max = at.max(X, axis=axis, keepdims=True)
545+
X_max = at.switch(at.isinf(X_max), 0, X_max)
546+
X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max
547+
548+
rng = np.random.default_rng(23920)
549+
X_val = rng.normal(size=size)
550+
551+
X_lse_fn = pytensor.function([X], X_lse, mode="JAX")
552+
553+
# JIT compile first
554+
_ = X_lse_fn(X_val)
555+
res = benchmark(X_lse_fn, X_val)
556+
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
557+
np.testing.assert_array_almost_equal(res, exp_res)

0 commit comments

Comments
 (0)