Skip to content

Commit 842bc52

Browse files
committed
Add benchmark tests for fused Elemwises
1 parent 2f94d1a commit 842bc52

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tests/link/numba/test_elemwise.py

+16
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor import config, function
1212
from pytensor.compile.ops import deep_copy_op
1313
from pytensor.compile.sharedvalue import SharedVariable
14+
from pytensor.gradient import grad
1415
from pytensor.graph.basic import Constant
1516
from pytensor.graph.fg import FunctionGraph
1617
from pytensor.tensor import elemwise as at_elemwise
@@ -555,3 +556,18 @@ def test_logsumexp_benchmark(size, axis, benchmark):
555556
res = benchmark(X_lse_fn, X_val)
556557
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
557558
np.testing.assert_array_almost_equal(res, exp_res)
559+
560+
561+
def test_fused_elemwise_benchmark(benchmark):
562+
rng = np.random.default_rng(123)
563+
size = 100_000
564+
x = pytensor.shared(rng.normal(size=size), name="x")
565+
mu = pytensor.shared(rng.normal(size=size), name="mu")
566+
567+
logp = -((x - mu) ** 2) / 2
568+
grad_logp = grad(logp.sum(), x)
569+
570+
func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
571+
# JIT compile first
572+
func()
573+
benchmark(func)

tests/tensor/rewriting/test_elemwise.py

+13
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.compile.function import function
1010
from pytensor.compile.mode import Mode, get_default_mode
1111
from pytensor.configdefaults import config
12+
from pytensor.gradient import grad
1213
from pytensor.graph.basic import Constant
1314
from pytensor.graph.fg import FunctionGraph
1415
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
@@ -1349,6 +1350,18 @@ def test_multiple_outputs_fused_root_elemwise(self):
13491350
assert len(nodes) == 1
13501351
assert isinstance(nodes[0].op.scalar_op, Composite)
13511352

1353+
def test_eval_benchmark(self, benchmark):
1354+
rng = np.random.default_rng(123)
1355+
size = 100_000
1356+
x = pytensor.shared(rng.normal(size=size), name="x")
1357+
mu = pytensor.shared(rng.normal(size=size), name="mu")
1358+
1359+
logp = -((x - mu) ** 2) / 2
1360+
grad_logp = grad(logp.sum(), x)
1361+
1362+
func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
1363+
benchmark(func)
1364+
13521365

13531366
class TimesN(aes.basic.UnaryScalarOp):
13541367
"""

0 commit comments

Comments
 (0)