1
1
import itertools
2
+ import operator
2
3
import sys
3
4
from collections import Counter , defaultdict , deque
4
5
from collections .abc import Generator
5
- from functools import cache
6
+ from functools import cache , reduce
6
7
from typing import TypeVar
7
8
from warnings import warn
8
9
16
17
from pytensor .graph .features import ReplaceValidate
17
18
from pytensor .graph .fg import Output
18
19
from pytensor .graph .rewriting .basic import (
19
- EquilibriumGraphRewriter ,
20
20
GraphRewriter ,
21
21
copy_stack_trace ,
22
22
in2out ,
23
23
node_rewriter ,
24
+ out2in ,
24
25
)
25
26
from pytensor .graph .rewriting .db import SequenceDB
26
27
from pytensor .graph .utils import InconsistencyError , MethodNotDefined
29
30
MakeVector ,
30
31
alloc ,
31
32
cast ,
33
+ constant ,
32
34
get_underlying_scalar_constant_value ,
33
35
)
34
36
from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
35
37
from pytensor .tensor .exceptions import NotScalarConstantError
36
- from pytensor .tensor .math import exp
38
+ from pytensor .tensor .math import add , exp , mul
37
39
from pytensor .tensor .rewriting .basic import (
38
40
alloc_like ,
41
+ broadcasted_by ,
39
42
register_canonicalize ,
40
43
register_specialize ,
41
44
)
@@ -542,8 +545,8 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
542
545
return rval
543
546
544
547
545
- @node_rewriter ([Elemwise ])
546
- def local_add_mul_fusion (fgraph , node ):
548
+ @node_rewriter ([add , mul ])
549
+ def flatten_nested_add_mul (fgraph , node ):
547
550
"""Fuse consecutive add or mul in one such node with more inputs.
548
551
549
552
It is better to fuse add/mul that way than in a Composite node as
@@ -554,27 +557,16 @@ def local_add_mul_fusion(fgraph, node):
554
557
This rewrite is almost useless after the AlgebraicCanonizer is used,
555
558
but it catches a few edge cases that are not canonicalized by it
556
559
"""
557
- if not (
558
- isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , ps .Add | ps .Mul )
559
- ):
560
- return False
561
-
562
- s_op = node .op .scalar_op .__class__
560
+ s_op = node .op .scalar_op
563
561
new_inp = []
564
562
fused = False
565
- nb_inputs = len (node .inputs )
566
- max_inputs = float ("inf" )
567
- if hasattr (node .op , "max_inputs" ):
568
- max_inputs = node .op .max_inputs (node )
569
563
for inp in node .inputs :
570
564
if (
571
565
inp .owner
572
566
and isinstance (inp .owner .op , Elemwise )
573
- and isinstance (inp .owner .op .scalar_op , s_op )
574
- and
567
+ and inp .owner .op .scalar_op == s_op
575
568
# Do not duplicate the operation.
576
- len (fgraph .clients [inp ]) == 1
577
- and (nb_inputs + len (inp .owner .inputs ) - 1 ) <= max_inputs
569
+ and len (fgraph .clients [inp ]) == 1
578
570
):
579
571
new_inp .extend (inp .owner .inputs )
580
572
fused = True
@@ -590,7 +582,7 @@ def local_add_mul_fusion(fgraph, node):
590
582
# Do the recursion here to help lower the number of
591
583
# FusionOptimizer iterations.
592
584
if output .owner :
593
- output2 = local_add_mul_fusion .transform (fgraph , output .owner )
585
+ output2 = flatten_nested_add_mul .transform (fgraph , output .owner )
594
586
if output2 :
595
587
return output2
596
588
return [output ]
@@ -1237,6 +1229,76 @@ def local_inline_composite_constants(fgraph, node):
1237
1229
return new_outputs
1238
1230
1239
1231
1232
@node_rewriter(tracks=[add, mul])
def constant_fold_branches_of_add_mul(fgraph, node):
    """Fold together constant inputs of an ``add`` or ``mul`` node.

    The constant inputs of the node are merged greedily.  One constant is
    taken as the reference and combined with every other constant for which
    ``broadcasted_by(reference, other)`` is false.  The combination is done
    eagerly, applying Python's ``+`` (for ``add``) or ``*`` (for ``mul``) to
    the constants' ``.data``.  The intent is to fold only when this does not
    result in higher intermediate memory.  The process restarts after every
    successful merge and stops when a single constant is left or no further
    merge is possible.

    Parameters
    ----------
    fgraph
        The graph being rewritten.  It is not used by this rewrite.
    node
        The ``add`` or ``mul`` node to inspect.

    Returns
    -------
    list or None
        A one-element list holding a new output of ``node.op`` whose inputs
        are the remaining (folded) constants followed by the non-constant
        inputs in their original order.  ``None`` when the node has at most
        one constant input or when no constants could be folded.
    """
    old_constants = [inp for inp in node.inputs if isinstance(inp, TensorConstant)]
    if len(old_constants) <= 1:
        # Nothing to combine.
        return None

    # The eager counterpart of the node's op does not change between
    # iterations, so it is picked once up front.
    python_op = operator.mul if node.op == mul else operator.add

    new_constants = old_constants.copy()
    while len(new_constants) > 1:
        for i, reference_inp in enumerate(new_constants):
            # Constants that can be merged into the reference one.
            # NOTE(review): presumably ``broadcasted_by(x, y)`` is true when
            # ``y`` would enlarge ``x`` in an Elemwise operation -- confirm
            # against its definition.
            other_inps = [
                other_inp
                for j, other_inp in enumerate(new_constants)
                if j != i and not broadcasted_by(reference_inp, other_inp)
            ]
            if not other_inps:
                continue

            folded_inputs = [reference_inp, *other_inps]
            new_inp = constant(
                reduce(python_op, (const.data for const in folded_inputs))
            )
            # The merged constant goes first; untouched constants keep their
            # relative order.
            new_constants = [
                new_inp,
                *(inp for inp in new_constants if inp not in folded_inputs),
            ]
            # Restart the scan on the reduced list of constants.
            break
        else:  # no-break: no pair of constants could be merged
            break

    if len(new_constants) == len(old_constants):
        # Every merge strictly reduces the number of constants, so an
        # unchanged count means nothing was folded.
        return None

    non_constants = [inp for inp in node.inputs if not isinstance(inp, TensorConstant)]
    new_out = node.op(
        *new_constants,
        *non_constants,
    )
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
1280
+
1281
+
1282
# Sequence of add/mul clean-up rewrites.  It is registered in the global
# ``optdb`` under the "fast_run" tag at position 48, before Elemwise fusion.
add_mul_fusion_seqopt = SequenceDB()
compile.optdb.register(
    "add_mul_fusion", add_mul_fusion_seqopt, "fast_run", position=48
)

# Step 0: merge nested add/mul nodes into a single node with more inputs.
# Newly created nodes are visited too (``ignore_newtrees=False``).
add_mul_fusion_seqopt.register(
    flatten_nested_add_mul.__name__,
    out2in(flatten_nested_add_mul, ignore_newtrees=False),
    "fast_run",
    position=0,
)

# Step 1: fold the constant inputs of the (now flattened) add/mul nodes.
# Nodes created by the rewrite itself are not revisited
# (``ignore_newtrees=True``).
add_mul_fusion_seqopt.register(
    constant_fold_branches_of_add_mul.__name__,
    in2out(constant_fold_branches_of_add_mul, ignore_newtrees=True),
    "fast_run",
    position=1,
)
1301
+
1240
1302
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
1241
1303
fuse_seqopt = SequenceDB ()
1242
1304
compile .optdb .register (
@@ -1248,14 +1310,6 @@ def local_inline_composite_constants(fgraph, node):
1248
1310
"FusionOptimizer" ,
1249
1311
position = 49 ,
1250
1312
)
1251
-
1252
- fuse_seqopt .register (
1253
- "local_add_mul_fusion" ,
1254
- EquilibriumGraphRewriter (rewriters = [local_add_mul_fusion ], max_use_ratio = 1000 ),
1255
- "fast_run" ,
1256
- "fusion" ,
1257
- position = 0 ,
1258
- )
1259
1313
fuse_seqopt .register (
1260
1314
"composite_elemwise_fusion" ,
1261
1315
FusionOptimizer (),
@@ -1279,7 +1333,7 @@ def local_inline_composite_constants(fgraph, node):
1279
1333
)
1280
1334
fuse_seqopt .register (
1281
1335
"local_inline_composite_constants" ,
1282
- in2out (local_inline_composite_constants ),
1336
+ in2out (local_inline_composite_constants , ignore_newtrees = True ),
1283
1337
"fast_run" ,
1284
1338
"fusion" ,
1285
1339
position = 20 ,
0 commit comments