Skip to content

Commit 880e928

Browse files
committed
Add benchmark test for FusionRewriter
1 parent 589ab7e commit 880e928

File tree

1 file changed

+37
-28
lines changed

1 file changed

+37
-28
lines changed

tests/tensor/rewriting/test_elemwise.py

+37-28
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@
4848
from pytensor.tensor.math import sin, sinh, sqr, sqrt
4949
from pytensor.tensor.math import sum as at_sum
5050
from pytensor.tensor.math import tan, tanh, true_div, xor
51-
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
51+
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
5252
from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
5353
from pytensor.tensor.shape import reshape
5454
from pytensor.tensor.type import (
@@ -302,6 +302,29 @@ def my_init(dtype="float64", num=0):
302302
fwx = fw + fx
303303
ftanx = tan(fx)
304304

305+
def large_fuseable_graph(self, n):
306+
factors = []
307+
sd = dscalar()
308+
means = dvector()
309+
310+
cst_05 = at.constant(0.5)
311+
cst_m05 = at.constant(-0.5)
312+
cst_2 = at.constant(2)
313+
cst_m2 = at.constant(-2)
314+
ones = at.constant(np.ones(10))
315+
316+
for i in range(n):
317+
f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
318+
cst_05 * (sd**cst_m2) / np.pi
319+
)
320+
factors.append(at_sum(f))
321+
322+
logp = add(*factors)
323+
324+
vars = [sd, means]
325+
dlogp = [pytensor.grad(logp, v) for v in vars]
326+
return vars, dlogp
327+
305328
@pytest.mark.parametrize(
306329
"case",
307330
[
@@ -1064,35 +1087,9 @@ def test_fusion_35_inputs(self):
10641087

10651088
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
10661089
def test_big_fusion(self):
1067-
# In the past, pickle of Composite generated in that case
1068-
# crashed with max recursion limit. So we were not able to
1069-
# generate C code in that case.
1070-
factors = []
1071-
sd = dscalar()
1072-
means = dvector()
1073-
1074-
cst_05 = at.constant(0.5)
1075-
cst_m05 = at.constant(-0.5)
1076-
cst_2 = at.constant(2)
1077-
cst_m2 = at.constant(-2)
1078-
ones = at.constant(np.ones(10))
1079-
n = 85
1080-
if config.mode in ["DebugMode", "DEBUG_MODE"]:
1081-
n = 10
1082-
1083-
for i in range(n):
1084-
f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
1085-
cst_05 * (sd**cst_m2) / np.pi
1086-
)
1087-
factors.append(at_sum(f))
1088-
1089-
logp = add(*factors)
1090-
1091-
vars = [sd, means]
1092-
10931090
# Make sure that C compilation is used
10941091
mode = Mode("cvm", self.rewrites)
1095-
dlogp = function(vars, [pytensor.grad(logp, v) for v in vars], mode=mode)
1092+
dlogp = function(*self.large_fuseable_graph(n=85), mode=mode)
10961093

10971094
# Make sure something was fused
10981095
assert any(
@@ -1367,6 +1364,18 @@ def test_eval_benchmark(self, benchmark):
13671364
func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
13681365
benchmark(func)
13691366

1367+
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
1368+
def test_rewrite_benchmark(self, benchmark):
1369+
inps, outs = self.large_fuseable_graph(n=25)
1370+
fg = FunctionGraph(inps, outs)
1371+
opt = FusionOptimizer()
1372+
1373+
def rewrite_func():
1374+
nb_replacement = opt.apply(fg.clone())[2]
1375+
return nb_replacement
1376+
1377+
assert benchmark(rewrite_func) == 103
1378+
13701379

13711380
class TimesN(aes.basic.UnaryScalarOp):
13721381
"""

0 commit comments

Comments
 (0)