Skip to content

Commit a2f101a

Browse files
committed
Cleanup Fusion rewrites
* Move local_add_mul_fusion to `rewriting/elemwise` and remove unused/duplicated TestAddMulFusion tests * Use EquilibriumGraphRewriter for local_add_mul_fusion * Do not register optional rewrites if tensor__local_elemwise_fusion flag is disabled
1 parent daabeb3 commit a2f101a

File tree

5 files changed

+84
-854
lines changed

5 files changed

+84
-854
lines changed

pytensor/graph/rewriting/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def query(
427427
position_cutoff = tags[0].position_cutoff
428428

429429
# The RewriteDatabaseQuery instance might contain extra rewrites which need
430-
# to be added the the sequence of rewrites (don't alter the
430+
# to be added to the sequence of rewrites (don't alter the
431431
# original dictionary)
432432
if len(tags[0].extra_rewrites) > 0:
433433
position_dict = position_dict.copy()

pytensor/tensor/rewriting/elemwise.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.graph.features import ReplaceValidate
1414
from pytensor.graph.op import compute_test_value, get_test_value
1515
from pytensor.graph.rewriting.basic import (
16+
EquilibriumGraphRewriter,
1617
GraphRewriter,
1718
copy_stack_trace,
1819
in2out,
@@ -529,6 +530,60 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
529530
return rval
530531

531532

533+
@node_rewriter([Elemwise])
534+
def local_add_mul_fusion(fgraph, node):
535+
"""Fuse consecutive add or mul in one such node with more inputs.
536+
537+
It is better to fuse add/mul that way then in a Composite node as
538+
this make the inner graph of the Composite smaller. This allows to
539+
put more computation in a Composite before hitting the max
540+
recursion limit when pickling Composite.
541+
542+
This rewrite is almost useless after the AlgebraicCanonizer is used,
543+
but it catches a few edge cases that are not canonicalized by it
544+
"""
545+
if not isinstance(node.op, Elemwise) or not isinstance(
546+
node.op.scalar_op, (aes.Add, aes.Mul)
547+
):
548+
return False
549+
550+
s_op = node.op.scalar_op.__class__
551+
new_inp = []
552+
fused = False
553+
nb_inputs = len(node.inputs)
554+
max_inputs = float("inf")
555+
if hasattr(node.op, "max_inputs"):
556+
max_inputs = node.op.max_inputs(node)
557+
for inp in node.inputs:
558+
if (
559+
inp.owner
560+
and isinstance(inp.owner.op, Elemwise)
561+
and isinstance(inp.owner.op.scalar_op, s_op)
562+
and
563+
# Do not duplicate the operation.
564+
len(fgraph.clients[inp]) == 1
565+
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
566+
):
567+
new_inp.extend(inp.owner.inputs)
568+
fused = True
569+
else:
570+
new_inp.append(inp)
571+
572+
# We can not compare the number of inputs as Mul and Add could have
573+
# 0 or 1 inputs in some corner cases.
574+
if fused:
575+
output = node.op(*new_inp)
576+
copy_stack_trace(node.outputs[0], output)
577+
578+
# Do the recursion here to help lower the number of
579+
# FusionOptimizer iteration.
580+
if output.owner:
581+
output2 = local_add_mul_fusion.transform(fgraph, output.owner)
582+
if output2:
583+
return output2
584+
return [output]
585+
586+
532587
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
533588
r"""Create a recursive function that fuses `Elemwise` `Op`\s.
534589
@@ -901,6 +956,13 @@ def print_profile(cls, stream, prof, level=0):
901956
if config.tensor__local_elemwise_fusion:
902957
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
903958
fuse_seqopt = SequenceDB()
959+
fuse_seqopt.register(
960+
"local_add_mul_fusion",
961+
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
962+
"fast_run",
963+
"fusion",
964+
position=0,
965+
)
904966
fuse_seqopt.register(
905967
"composite_elemwise_fusion",
906968
FusionOptimizer(local_elemwise_fusion),
@@ -917,15 +979,6 @@ def print_profile(cls, stream, prof, level=0):
917979
"FusionOptimizer",
918980
position=49,
919981
)
920-
else:
921-
compile.optdb.register( # type: ignore
922-
"elemwise_fusion",
923-
FusionOptimizer(local_elemwise_fusion),
924-
"fusion",
925-
"local_elemwise_fusion",
926-
"FusionOptimizer",
927-
position=49,
928-
)
929982

930983

931984
@register_canonicalize

pytensor/tensor/rewriting/math.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@
9292
register_uncanonicalize,
9393
register_useless,
9494
)
95-
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
9695
from pytensor.tensor.shape import Shape, Shape_i
9796
from pytensor.tensor.subtensor import Subtensor
9897
from pytensor.tensor.type import (
@@ -2966,66 +2965,6 @@ def check_input(inputs):
29662965
return [ret]
29672966

29682967

2969-
def local_add_mul_fusion(fgraph, node):
2970-
"""Fuse consecutive add or mul in one such node with more inputs.
2971-
2972-
It is better to fuse add/mul that way then in a Composite node as
2973-
this make the inner graph of the Composite smaller. This allow to
2974-
put more computation in a Composite before hitting the max
2975-
recursion limit when pickling Composite.
2976-
2977-
"""
2978-
if not isinstance(node.op, Elemwise) or not isinstance(
2979-
node.op.scalar_op, (aes.Add, aes.Mul)
2980-
):
2981-
return False
2982-
2983-
s_op = node.op.scalar_op.__class__
2984-
new_inp = []
2985-
fused = False
2986-
nb_inputs = len(node.inputs)
2987-
max_inputs = float("inf")
2988-
if hasattr(node.op, "max_inputs"):
2989-
max_inputs = node.op.max_inputs(node)
2990-
for inp in node.inputs:
2991-
if (
2992-
inp.owner
2993-
and isinstance(inp.owner.op, Elemwise)
2994-
and isinstance(inp.owner.op.scalar_op, s_op)
2995-
and
2996-
# Do not duplicate the operation.
2997-
len(fgraph.clients[inp]) == 1
2998-
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
2999-
):
3000-
new_inp.extend(inp.owner.inputs)
3001-
fused = True
3002-
else:
3003-
new_inp.append(inp)
3004-
3005-
# We can not compare the number of inputs as Mul and Add could have
3006-
# 0 or 1 inputs in some corner cases.
3007-
if fused:
3008-
output = node.op(*new_inp)
3009-
copy_stack_trace(node.outputs[0], output)
3010-
3011-
# Do the recursion here to help lower the number of
3012-
# FusionOptimizer iteration.
3013-
if output.owner:
3014-
output2 = local_add_mul_fusion(fgraph, output.owner)
3015-
if output2:
3016-
return output2
3017-
return [output]
3018-
3019-
3020-
fuse_seqopt.register(
3021-
"local_add_mul_fusion",
3022-
FusionOptimizer(local_add_mul_fusion),
3023-
"fast_run",
3024-
"fusion",
3025-
position=0,
3026-
)
3027-
3028-
30292968
def _skip_mul_1(r):
30302969
if r.owner and r.owner.op == mul:
30312970
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]

tests/tensor/rewriting/test_elemwise.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import pytest
55

66
import pytensor
7-
import pytensor.scalar as aes
8-
import pytensor.tensor as at
7+
from pytensor import scalar as aes
98
from pytensor import shared
9+
from pytensor import tensor as at
1010
from pytensor.compile.function import function
1111
from pytensor.compile.mode import Mode, get_default_mode
1212
from pytensor.configdefaults import config
@@ -263,9 +263,8 @@ def test_local_useless_dimshuffle_in_reshape():
263263
class TestFusion:
264264
rewrites = RewriteDatabaseQuery(
265265
include=[
266-
"local_elemwise_fusion",
267-
"composite_elemwise_fusion",
268266
"canonicalize",
267+
"fusion",
269268
"inplace",
270269
],
271270
exclude=["cxx_only", "BlasOpt"],
@@ -1007,22 +1006,10 @@ def test_big_fusion(self):
10071006
)
10081007

10091008
def test_add_mul_fusion_inplace(self):
1010-
1011-
rewrites = RewriteDatabaseQuery(
1012-
include=[
1013-
"local_elemwise_fusion",
1014-
"composite_elemwise_fusion",
1015-
"canonicalize",
1016-
"inplace",
1017-
],
1018-
exclude=["cxx_only", "BlasOpt"],
1019-
)
1020-
1021-
mode = Mode(self.mode.linker, rewrites)
1022-
10231009
x, y, z = dmatrices("xyz")
10241010
out = dot(x, y) + x + y + z
1025-
f = function([x, y, z], out, mode=mode)
1011+
1012+
f = function([x, y, z], out, mode=self.mode)
10261013
topo = [n for n in f.maker.fgraph.toposort()]
10271014
assert len(topo) == 2
10281015
assert topo[-1].op.inplace_pattern
@@ -1050,8 +1037,7 @@ def impl(self, x):
10501037

10511038
mode = Mode(linker="cvm")
10521039
mode._optimizer = mode._optimizer.including(
1053-
"local_elemwise_fusion",
1054-
"composite_elemwise_fusion",
1040+
"fusion",
10551041
"canonicalize",
10561042
"inplace",
10571043
)
@@ -1073,18 +1059,6 @@ def test_test_values(self, test_value):
10731059
are checked.
10741060
10751061
"""
1076-
1077-
rewrites = RewriteDatabaseQuery(
1078-
include=[
1079-
"local_elemwise_fusion",
1080-
"composite_elemwise_fusion",
1081-
"canonicalize",
1082-
],
1083-
exclude=["cxx_only", "BlasOpt"],
1084-
)
1085-
1086-
mode = Mode(self.mode.linker, rewrites)
1087-
10881062
x, y, z = dmatrices("xyz")
10891063

10901064
x.tag.test_value = test_value
@@ -1101,7 +1075,7 @@ def test_test_values(self, test_value):
11011075
):
11021076
out = x * y + z
11031077
with cm:
1104-
f = function([x, y, z], out, mode=mode)
1078+
f = function([x, y, z], out, mode=self.mode)
11051079

11061080
if test_value.size != 0:
11071081
# Confirm that the fusion happened

0 commit comments

Comments
 (0)