|
48 | 48 | from pytensor.tensor.math import sin, sinh, sqr, sqrt
|
49 | 49 | from pytensor.tensor.math import sum as at_sum
|
50 | 50 | from pytensor.tensor.math import tan, tanh, true_div, xor
|
51 |
| -from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift |
| 51 | +from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift |
52 | 52 | from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
|
53 | 53 | from pytensor.tensor.shape import reshape
|
54 | 54 | from pytensor.tensor.type import (
|
@@ -302,6 +302,29 @@ def my_init(dtype="float64", num=0):
|
302 | 302 | fwx = fw + fx
|
303 | 303 | ftanx = tan(fx)
|
304 | 304 |
|
| 305 | + def large_fuseable_graph(self, n): |
| 306 | + factors = [] |
| 307 | + sd = dscalar() |
| 308 | + means = dvector() |
| 309 | + |
| 310 | + cst_05 = at.constant(0.5) |
| 311 | + cst_m05 = at.constant(-0.5) |
| 312 | + cst_2 = at.constant(2) |
| 313 | + cst_m2 = at.constant(-2) |
| 314 | + ones = at.constant(np.ones(10)) |
| 315 | + |
| 316 | + for i in range(n): |
| 317 | + f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( |
| 318 | + cst_05 * (sd**cst_m2) / np.pi |
| 319 | + ) |
| 320 | + factors.append(at_sum(f)) |
| 321 | + |
| 322 | + logp = add(*factors) |
| 323 | + |
| 324 | + vars = [sd, means] |
| 325 | + dlogp = [pytensor.grad(logp, v) for v in vars] |
| 326 | + return vars, dlogp |
| 327 | + |
305 | 328 | @pytest.mark.parametrize(
|
306 | 329 | "case",
|
307 | 330 | [
|
@@ -1064,35 +1087,9 @@ def test_fusion_35_inputs(self):
|
1064 | 1087 |
|
1065 | 1088 | @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
|
1066 | 1089 | def test_big_fusion(self):
|
1067 |
| - # In the past, pickle of Composite generated in that case |
1068 |
| - # crashed with max recursion limit. So we were not able to |
1069 |
| - # generate C code in that case. |
1070 |
| - factors = [] |
1071 |
| - sd = dscalar() |
1072 |
| - means = dvector() |
1073 |
| - |
1074 |
| - cst_05 = at.constant(0.5) |
1075 |
| - cst_m05 = at.constant(-0.5) |
1076 |
| - cst_2 = at.constant(2) |
1077 |
| - cst_m2 = at.constant(-2) |
1078 |
| - ones = at.constant(np.ones(10)) |
1079 |
| - n = 85 |
1080 |
| - if config.mode in ["DebugMode", "DEBUG_MODE"]: |
1081 |
| - n = 10 |
1082 |
| - |
1083 |
| - for i in range(n): |
1084 |
| - f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( |
1085 |
| - cst_05 * (sd**cst_m2) / np.pi |
1086 |
| - ) |
1087 |
| - factors.append(at_sum(f)) |
1088 |
| - |
1089 |
| - logp = add(*factors) |
1090 |
| - |
1091 |
| - vars = [sd, means] |
1092 |
| - |
1093 | 1090 | # Make sure that C compilation is used
|
1094 | 1091 | mode = Mode("cvm", self.rewrites)
|
1095 |
| - dlogp = function(vars, [pytensor.grad(logp, v) for v in vars], mode=mode) |
| 1092 | + dlogp = function(*self.large_fuseable_graph(n=85), mode=mode) |
1096 | 1093 |
|
1097 | 1094 | # Make sure something was fused
|
1098 | 1095 | assert any(
|
@@ -1367,6 +1364,18 @@ def test_eval_benchmark(self, benchmark):
|
1367 | 1364 | func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
|
1368 | 1365 | benchmark(func)
|
1369 | 1366 |
|
| 1367 | + @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") |
| 1368 | + def test_rewrite_benchmark(self, benchmark): |
| 1369 | + inps, outs = self.large_fuseable_graph(n=25) |
| 1370 | + fg = FunctionGraph(inps, outs) |
| 1371 | + opt = FusionOptimizer() |
| 1372 | + |
| 1373 | + def rewrite_func(): |
| 1374 | + nb_replacement = opt.apply(fg.clone())[2] |
| 1375 | + return nb_replacement |
| 1376 | + |
| 1377 | + assert benchmark(rewrite_func) == 103 |
| 1378 | + |
1370 | 1379 |
|
1371 | 1380 | class TimesN(aes.basic.UnaryScalarOp):
|
1372 | 1381 | """
|
|
0 commit comments