Skip to content

Commit 79d98f1

Browse files
committed
Add benchmark test for FusionRewriter
1 parent 842bc52 commit 79d98f1

File tree

1 file changed

+37
-28
lines changed

1 file changed

+37
-28
lines changed

tests/tensor/rewriting/test_elemwise.py

+37 −28
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from pytensor.tensor.math import sin, sinh, sqr, sqrt
4949
from pytensor.tensor.math import sum as at_sum
5050
from pytensor.tensor.math import tan, tanh, true_div, xor
51-
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
51+
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
5252
from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
5353
from pytensor.tensor.shape import reshape
5454
from pytensor.tensor.type import (
@@ -302,6 +302,29 @@ def my_init(dtype="float64", num=0):
302302
fwx = fw + fx
303303
ftanx = tan(fx)
304304

305+
def large_fuseable_graph(self, n):
306+
factors = []
307+
sd = dscalar()
308+
means = dvector()
309+
310+
cst_05 = at.constant(0.5)
311+
cst_m05 = at.constant(-0.5)
312+
cst_2 = at.constant(2)
313+
cst_m2 = at.constant(-2)
314+
ones = at.constant(np.ones(10))
315+
316+
for i in range(n):
317+
f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
318+
cst_05 * (sd**cst_m2) / np.pi
319+
)
320+
factors.append(at_sum(f))
321+
322+
logp = add(*factors)
323+
324+
vars = [sd, means]
325+
dlogp = [pytensor.grad(logp, v) for v in vars]
326+
return vars, dlogp
327+
305328
@pytest.mark.parametrize(
306329
"case",
307330
[
@@ -1059,35 +1082,9 @@ def test_fusion_35_inputs(self):
10591082

10601083
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
10611084
def test_big_fusion(self):
1062-
# In the past, pickle of Composite generated in that case
1063-
# crashed with max recursion limit. So we were not able to
1064-
# generate C code in that case.
1065-
factors = []
1066-
sd = dscalar()
1067-
means = dvector()
1068-
1069-
cst_05 = at.constant(0.5)
1070-
cst_m05 = at.constant(-0.5)
1071-
cst_2 = at.constant(2)
1072-
cst_m2 = at.constant(-2)
1073-
ones = at.constant(np.ones(10))
1074-
n = 85
1075-
if config.mode in ["DebugMode", "DEBUG_MODE"]:
1076-
n = 10
1077-
1078-
for i in range(n):
1079-
f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
1080-
cst_05 * (sd**cst_m2) / np.pi
1081-
)
1082-
factors.append(at_sum(f))
1083-
1084-
logp = add(*factors)
1085-
1086-
vars = [sd, means]
1087-
10881085
# Make sure that C compilation is used
10891086
mode = Mode("cvm", self.rewrites)
1090-
dlogp = function(vars, [pytensor.grad(logp, v) for v in vars], mode=mode)
1087+
dlogp = function(*self.large_fuseable_graph(n=85), mode=mode)
10911088

10921089
# Make sure something was fused
10931090
assert any(
@@ -1362,6 +1359,18 @@ def test_eval_benchmark(self, benchmark):
13621359
func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
13631360
benchmark(func)
13641361

1362+
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
1363+
def test_rewrite_benchmark(self, benchmark):
1364+
inps, outs = self.large_fuseable_graph(n=25)
1365+
fg = FunctionGraph(inps, outs)
1366+
opt = FusionOptimizer()
1367+
1368+
def rewrite_func():
1369+
nb_replacement = opt.apply(fg.clone())[2]
1370+
return nb_replacement
1371+
1372+
assert benchmark(rewrite_func) == 103
1373+
13651374

13661375
class TimesN(aes.basic.UnaryScalarOp):
13671376
"""

0 commit comments

Comments (0)