Skip to content

Commit 589ab7e

Browse files
committed
Add benchmark tests for fused Elemwises
1 parent 2b1956a commit 589ab7e

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tests/link/numba/test_elemwise.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor import config, function
1212
from pytensor.compile.ops import deep_copy_op
1313
from pytensor.compile.sharedvalue import SharedVariable
14+
from pytensor.gradient import grad
1415
from pytensor.graph.basic import Constant
1516
from pytensor.graph.fg import FunctionGraph
1617
from pytensor.tensor import elemwise as at_elemwise
@@ -555,3 +556,18 @@ def test_logsumexp_benchmark(size, axis, benchmark):
555556
res = benchmark(X_lse_fn, X_val)
556557
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
557558
np.testing.assert_array_almost_equal(res, exp_res)
559+
560+
561+
def test_fused_elemwise_benchmark(benchmark):
562+
rng = np.random.default_rng(123)
563+
size = 100_000
564+
x = pytensor.shared(rng.normal(size=size), name="x")
565+
mu = pytensor.shared(rng.normal(size=size), name="mu")
566+
567+
logp = -((x - mu) ** 2) / 2
568+
grad_logp = grad(logp.sum(), x)
569+
570+
func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
571+
# JIT compile first
572+
func()
573+
benchmark(func)

tests/tensor/rewriting/test_elemwise.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.compile.function import function
1010
from pytensor.compile.mode import Mode, get_default_mode
1111
from pytensor.configdefaults import config
12+
from pytensor.gradient import grad
1213
from pytensor.graph.basic import Constant
1314
from pytensor.graph.fg import FunctionGraph
1415
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
@@ -1354,6 +1355,18 @@ def test_multiple_outputs_fused_root_elemwise(self):
13541355
assert len(nodes) == 1
13551356
assert isinstance(nodes[0].op.scalar_op, Composite)
13561357

1358+
def test_eval_benchmark(self, benchmark):
1359+
rng = np.random.default_rng(123)
1360+
size = 100_000
1361+
x = pytensor.shared(rng.normal(size=size), name="x")
1362+
mu = pytensor.shared(rng.normal(size=size), name="mu")
1363+
1364+
logp = -((x - mu) ** 2) / 2
1365+
grad_logp = grad(logp.sum(), x)
1366+
1367+
func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
1368+
benchmark(func)
1369+
13571370

13581371
class TimesN(aes.basic.UnaryScalarOp):
13591372
"""

0 commit comments

Comments
 (0)