Skip to content

Commit 683479c

Browse files
committed
Add benchmark for numba elemwise
1 parent c0710c9 commit 683479c

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

tests/link/numba/test_elemwise.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytensor.tensor as at
77
import pytensor.tensor.inplace as ati
88
import pytensor.tensor.math as aem
9-
from pytensor import config
9+
from pytensor import config, function
1010
from pytensor.compile.ops import deep_copy_op
1111
from pytensor.compile.sharedvalue import SharedVariable
1212
from pytensor.graph.basic import Constant
@@ -117,6 +117,25 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
117117
compare_numba_and_py(out_fg, input_vals)
118118

119119

120+
def test_elemwise_speed(benchmark):
121+
x = at.dmatrix("y")
122+
y = at.dvector("z")
123+
124+
out = np.exp(2 * x * y + y)
125+
126+
rng = np.random.default_rng(42)
127+
128+
x_val = rng.normal(size=(200, 500))
129+
y_val = rng.normal(size=500)
130+
131+
func = function([x, y], out, mode="NUMBA")
132+
func = func.vm.jit_fn
133+
(out,) = func(x_val, y_val)
134+
np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)
135+
136+
benchmark(func, x_val, y_val)
137+
138+
120139
@pytest.mark.parametrize(
121140
"v, new_order",
122141
[

0 commit comments

Comments
 (0)