|
48 | 48 | from pytensor.tensor.math import sin, sinh, sqr, sqrt
|
49 | 49 | from pytensor.tensor.math import sum as at_sum
|
50 | 50 | from pytensor.tensor.math import tan, tanh, true_div, xor
|
51 |
| -from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift |
| 51 | +from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift |
52 | 52 | from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
|
53 | 53 | from pytensor.tensor.shape import reshape
|
54 | 54 | from pytensor.tensor.type import (
|
@@ -302,6 +302,29 @@ def my_init(dtype="float64", num=0):
|
302 | 302 | fwx = fw + fx
|
303 | 303 | ftanx = tan(fx)
|
304 | 304 |
|
def large_fuseable_graph(self, n):
    """Build a log-probability-style graph of ``n`` summed terms plus its gradients.

    Each term combines a ``dscalar`` ``sd`` with one indexed entry of a
    ``dvector`` ``means`` through elementwise operations, and is then reduced
    with a sum. The ``n`` reduced terms are added together.

    Returns a pair ``(inputs, gradients)``: ``inputs`` is ``[sd, means]`` and
    ``gradients`` holds the gradient of the total with respect to each input,
    in the same order.
    """
    sd = dscalar()
    means = dvector()

    half = at.constant(0.5)
    neg_half = at.constant(-0.5)
    two = at.constant(2)
    neg_two = at.constant(-2)
    ones = at.constant(np.ones(10))

    # One reduced term per index of `means`; all terms share the same constants.
    terms = [
        at_sum(
            neg_half * sd**neg_two * (ones - means[i]) ** two
            + half * log(half * (sd**neg_two) / np.pi)
        )
        for i in range(n)
    ]
    total = add(*terms)

    inputs = [sd, means]
    gradients = [pytensor.grad(total, wrt) for wrt in inputs]
    return inputs, gradients
| 327 | + |
305 | 328 | @pytest.mark.parametrize(
|
306 | 329 | "case",
|
307 | 330 | [
|
@@ -1059,35 +1082,9 @@ def test_fusion_35_inputs(self):
|
1059 | 1082 |
|
1060 | 1083 | @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
|
1061 | 1084 | def test_big_fusion(self):
|
1062 |
| - # In the past, pickle of Composite generated in that case |
1063 |
| - # crashed with max recursion limit. So we were not able to |
1064 |
| - # generate C code in that case. |
1065 |
| - factors = [] |
1066 |
| - sd = dscalar() |
1067 |
| - means = dvector() |
1068 |
| - |
1069 |
| - cst_05 = at.constant(0.5) |
1070 |
| - cst_m05 = at.constant(-0.5) |
1071 |
| - cst_2 = at.constant(2) |
1072 |
| - cst_m2 = at.constant(-2) |
1073 |
| - ones = at.constant(np.ones(10)) |
1074 |
| - n = 85 |
1075 |
| - if config.mode in ["DebugMode", "DEBUG_MODE"]: |
1076 |
| - n = 10 |
1077 |
| - |
1078 |
| - for i in range(n): |
1079 |
| - f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( |
1080 |
| - cst_05 * (sd**cst_m2) / np.pi |
1081 |
| - ) |
1082 |
| - factors.append(at_sum(f)) |
1083 |
| - |
1084 |
| - logp = add(*factors) |
1085 |
| - |
1086 |
| - vars = [sd, means] |
1087 |
| - |
1088 | 1085 | # Make sure that C compilation is used
|
1089 | 1086 | mode = Mode("cvm", self.rewrites)
|
1090 |
| - dlogp = function(vars, [pytensor.grad(logp, v) for v in vars], mode=mode) |
| 1087 | + dlogp = function(*self.large_fuseable_graph(n=85), mode=mode) |
1091 | 1088 |
|
1092 | 1089 | # Make sure something was fused
|
1093 | 1090 | assert any(
|
@@ -1362,6 +1359,18 @@ def test_eval_benchmark(self, benchmark):
|
1362 | 1359 | func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
|
1363 | 1360 | benchmark(func)
|
1364 | 1361 |
|
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_rewrite_benchmark(self, benchmark):
    """Benchmark ``FusionOptimizer.apply`` on the 25-term graph from ``large_fuseable_graph``.

    The value returned by the benchmarked callable is also checked against
    the expected count of 103.
    """
    graph_inputs, graph_outputs = self.large_fuseable_graph(n=25)
    base_graph = FunctionGraph(graph_inputs, graph_outputs)
    fusion = FusionOptimizer()

    def run_fusion():
        # Every call rewrites a fresh clone, so `base_graph` itself is never
        # handed to `apply`.
        apply_result = fusion.apply(base_graph.clone())
        # Index 2 of the result is what this test treats as the number of
        # replacements performed.
        return apply_result[2]

    replacement_count = benchmark(run_fusion)
    assert replacement_count == 103
| 1373 | + |
1365 | 1374 |
|
1366 | 1375 | class TimesN(aes.basic.UnaryScalarOp):
|
1367 | 1376 | """
|
|
0 commit comments