Skip to content

Commit ca12b58

Browse files
committed
Constant fold branches of variadic add/mul
1 parent 261aaf3 commit ca12b58

File tree

2 files changed

+104
-29
lines changed

2 files changed

+104
-29
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 83 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import itertools
2+
import operator
23
import sys
34
from collections import Counter, defaultdict, deque
45
from collections.abc import Generator
5-
from functools import cache
6+
from functools import cache, reduce
67
from typing import TypeVar
78
from warnings import warn
89

@@ -16,11 +17,11 @@
1617
from pytensor.graph.features import ReplaceValidate
1718
from pytensor.graph.fg import Output
1819
from pytensor.graph.rewriting.basic import (
19-
EquilibriumGraphRewriter,
2020
GraphRewriter,
2121
copy_stack_trace,
2222
in2out,
2323
node_rewriter,
24+
out2in,
2425
)
2526
from pytensor.graph.rewriting.db import SequenceDB
2627
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
@@ -29,13 +30,15 @@
2930
MakeVector,
3031
alloc,
3132
cast,
33+
constant,
3234
get_underlying_scalar_constant_value,
3335
)
3436
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
3537
from pytensor.tensor.exceptions import NotScalarConstantError
36-
from pytensor.tensor.math import exp
38+
from pytensor.tensor.math import add, exp, mul
3739
from pytensor.tensor.rewriting.basic import (
3840
alloc_like,
41+
broadcasted_by,
3942
register_canonicalize,
4043
register_specialize,
4144
)
@@ -542,8 +545,8 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
542545
return rval
543546

544547

545-
@node_rewriter([Elemwise])
546-
def local_add_mul_fusion(fgraph, node):
548+
@node_rewriter([add, mul])
549+
def flatten_nested_add_mul(fgraph, node):
547550
"""Fuse consecutive add or mul in one such node with more inputs.
548551
549552
It is better to fuse add/mul that way than in a Composite node as
@@ -554,27 +557,16 @@ def local_add_mul_fusion(fgraph, node):
554557
This rewrite is almost useless after the AlgebraicCanonizer is used,
555558
but it catches a few edge cases that are not canonicalized by it
556559
"""
557-
if not (
558-
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Add | ps.Mul)
559-
):
560-
return False
561-
562-
s_op = node.op.scalar_op.__class__
560+
s_op = node.op.scalar_op
563561
new_inp = []
564562
fused = False
565-
nb_inputs = len(node.inputs)
566-
max_inputs = float("inf")
567-
if hasattr(node.op, "max_inputs"):
568-
max_inputs = node.op.max_inputs(node)
569563
for inp in node.inputs:
570564
if (
571565
inp.owner
572566
and isinstance(inp.owner.op, Elemwise)
573-
and isinstance(inp.owner.op.scalar_op, s_op)
574-
and
567+
and inp.owner.op.scalar_op == s_op
575568
# Do not duplicate the operation.
576-
len(fgraph.clients[inp]) == 1
577-
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
569+
and len(fgraph.clients[inp]) == 1
578570
):
579571
new_inp.extend(inp.owner.inputs)
580572
fused = True
@@ -590,7 +582,7 @@ def local_add_mul_fusion(fgraph, node):
590582
# Do the recursion here to help lower the number of
591583
# FusionOptimizer iteration.
592584
if output.owner:
593-
output2 = local_add_mul_fusion.transform(fgraph, output.owner)
585+
output2 = flatten_nested_add_mul.transform(fgraph, output.owner)
594586
if output2:
595587
return output2
596588
return [output]
@@ -1237,6 +1229,76 @@ def local_inline_composite_constants(fgraph, node):
12371229
return new_outputs
12381230

12391231

1232+
@node_rewriter(tracks=[add, mul])
1233+
def constant_fold_branches_of_add_mul(fgraph, node):
1234+
old_constants = [inp for inp in node.inputs if isinstance(inp, TensorConstant)]
1235+
1236+
if len(old_constants) <= 1:
1237+
return None
1238+
1239+
new_constants = old_constants.copy()
1240+
1241+
# Multiply constants if it doesn't result in higher intermediate memory
1242+
while True:
1243+
n_constants = len(new_constants)
1244+
if n_constants <= 1:
1245+
break
1246+
1247+
for i in range(n_constants):
1248+
reference_inp = new_constants[i]
1249+
other_inps = []
1250+
for j in range(n_constants):
1251+
if i == j:
1252+
continue
1253+
other_inp = new_constants[j]
1254+
if not broadcasted_by(reference_inp, other_inp):
1255+
other_inps.append(other_inp)
1256+
if other_inps:
1257+
python_op = operator.mul if node.op == mul else operator.add
1258+
folded_inputs = [reference_inp, *other_inps]
1259+
new_inp = constant(
1260+
reduce(python_op, (const.data for const in folded_inputs))
1261+
)
1262+
new_constants = [
1263+
new_inp,
1264+
*(inp for inp in new_constants if inp not in folded_inputs),
1265+
]
1266+
break
1267+
else: # no-break
1268+
break
1269+
1270+
if len(new_constants) == len(old_constants):
1271+
return None
1272+
1273+
non_constants = [inp for inp in node.inputs if not isinstance(inp, TensorConstant)]
1274+
new_out = node.op(
1275+
*new_constants,
1276+
*non_constants,
1277+
)
1278+
copy_stack_trace(node.outputs[0], new_out)
1279+
return [new_out]
1280+
1281+
1282+
add_mul_flat_seqopt = SequenceDB()
1283+
compile.optdb.register(
1284+
"add_mul_flat",
1285+
add_mul_flat_seqopt,
1286+
"fast_run",
1287+
position=48, # Before Elemwise fusion
1288+
)
1289+
add_mul_flat_seqopt.register(
1290+
flatten_nested_add_mul.__name__,
1291+
out2in(flatten_nested_add_mul, ignore_newtrees=False),
1292+
"fast_run",
1293+
position=0,
1294+
)
1295+
add_mul_flat_seqopt.register(
1296+
constant_fold_branches_of_add_mul.__name__,
1297+
in2out(constant_fold_branches_of_add_mul, ignore_newtrees=True),
1298+
"fast_run",
1299+
position=1,
1300+
)
1301+
12401302
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
12411303
fuse_seqopt = SequenceDB()
12421304
compile.optdb.register(
@@ -1248,14 +1310,6 @@ def local_inline_composite_constants(fgraph, node):
12481310
"FusionOptimizer",
12491311
position=49,
12501312
)
1251-
1252-
fuse_seqopt.register(
1253-
"local_add_mul_fusion",
1254-
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
1255-
"fast_run",
1256-
"fusion",
1257-
position=0,
1258-
)
12591313
fuse_seqopt.register(
12601314
"composite_elemwise_fusion",
12611315
FusionOptimizer(),
@@ -1279,7 +1333,7 @@ def local_inline_composite_constants(fgraph, node):
12791333
)
12801334
fuse_seqopt.register(
12811335
"local_inline_composite_constants",
1282-
in2out(local_inline_composite_constants),
1336+
in2out(local_inline_composite_constants, ignore_newtrees=True),
12831337
"fast_run",
12841338
"fusion",
12851339
position=20,

tests/tensor/rewriting/test_elemwise.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,3 +1507,24 @@ def test_local_useless_dimshuffle_makevector():
15071507
)
15081508

15091509
assert y_rewritten_fg.outputs[0] == a
1510+
1511+
1512+
@pytest.mark.parametrize("op", (add, mul))
1513+
def test_constant_fold_branches_add_mul(op):
1514+
rng = np.random.default_rng()
1515+
py_op = np.add if op is add else np.multiply
1516+
1517+
x = pt.vector("x")
1518+
a = rng.normal(size=(1, 512, 5))
1519+
b = rng.normal(size=(1, 512, 1))
1520+
out = op(op(a, x), b)
1521+
new_out = rewrite_graph(out, include=("fast_run",), exclude=("inplace",))
1522+
assert len(new_out.owner.inputs) == 2
1523+
assert equal_computations([new_out], [op(py_op(a, b), x)])
1524+
1525+
# c shouldn't be folded as it would increase the memory usage
1526+
c = rng.normal(size=(1024, 1, 1))
1527+
out = op(op(op(a, x), c), b)
1528+
new_out = rewrite_graph(out, include=("fast_run",), exclude=("inplace",))
1529+
assert len(new_out.owner.inputs) == 3
1530+
assert equal_computations([new_out], [op(py_op(a, b), c, x)])

0 commit comments

Comments (0)