@@ -593,20 +593,13 @@ def local_add_mul_fusion(fgraph, node):
593
593
return [output ]
594
594
595
595
596
- def elemwise_max_operands_fct (node ) -> int :
597
- # `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs)
598
- if not config .cxx :
599
- return 32
600
- return 1024
601
-
602
-
603
596
class FusionOptimizer (GraphRewriter ):
604
597
"""Graph optimizer that fuses consecutive Elemwise operations."""
605
598
606
- def __init__ (self , local_optimizer = None ):
607
- # TODO: Figure out what to do with this
599
+ def __init__ (self , backend ):
608
600
super ().__init__ ()
609
- self .optimizer = local_optimizer
601
+ assert backend in ("py" , "c" , "numba" )
602
+ self .backend = backend
610
603
611
604
def add_requirements (self , fgraph ):
612
605
fgraph .attach_feature (ReplaceValidate ())
@@ -654,29 +647,29 @@ def elemwise_to_scalar(inputs, outputs):
654
647
return scalar_inputs , scalar_outputs
655
648
656
649
def apply (self , fgraph ):
650
+ # Even though this rewrite it marked as `cxx_only`,
651
+ # it may sometimes be called when `cxx` is disabled -.-
652
+ if self .backend == "c" and not config .cxx :
653
+ return
654
+
657
655
nb_replacement = 0
658
656
659
657
if fgraph .profile :
660
658
validate_before = fgraph .profile .validate_time
661
659
callbacks_before = fgraph .execute_callbacks_times .copy ()
662
660
callback_before = fgraph .execute_callbacks_time
663
661
664
- max_operands = elemwise_max_operands_fct (None )
662
+ # `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs)
663
+ max_operands = 32 if self .backend == "py" else 1024
665
664
666
- def find_next_fuseable_subgraph (
667
- fg : FunctionGraph ,
668
- ) -> Generator [Tuple [List [Variable ], List [Variable ]], None , None ]:
669
- """Find all subgraphs in a FunctionGraph that can be fused together
670
-
671
- Yields
672
- -------
673
- List of inputs and outputs that determine subgraphs which can be fused. This
674
- method assumes that such replacement is done across iterations of the
675
- generator.
676
- """
665
+ if self .backend in ("py" , "c" ):
666
+ # Python mode is not really a backend, and it may or may not call C code
667
+ # Rewrites don't have access to the linker to make this decision, So we assume
668
+ # we can only fuse Ops with C implementation
677
669
670
+ # Python rewrite may
678
671
@lru_cache (maxsize = None )
679
- def elemwise_scalar_op_has_c_code (node : Apply ) -> bool :
672
+ def elemwise_scalar_op_can_be_fused (node : Apply ) -> bool :
680
673
if node .op .scalar_op .supports_c_code (node .inputs , node .outputs ):
681
674
return True
682
675
else :
@@ -690,6 +683,24 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
690
683
)
691
684
return False
692
685
686
+ elif self .backend == "numba" :
687
+
688
+ def elemwise_scalar_op_can_be_fused (node : Apply ) -> bool :
689
+ # Should we truncate at numba elemwise ops that need to run in object mode?
690
+ return True
691
+
692
+ def find_next_fuseable_subgraph (
693
+ fg : FunctionGraph ,
694
+ ) -> Generator [Tuple [List [Variable ], List [Variable ]], None , None ]:
695
+ """Find all subgraphs in a FunctionGraph that can be fused together
696
+
697
+ Yields
698
+ -------
699
+ List of inputs and outputs that determine subgraphs which can be fused. This
700
+ method assumes that such replacement is done across iterations of the
701
+ generator.
702
+ """
703
+
693
704
# We start by creating two maps, 1) from each node to each potentially
694
705
# fuseable client (both nodes must be single output Elemwise with same
695
706
# broadcast type) and 2) from each node to each certainly unfuseable
@@ -702,7 +713,7 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
702
713
and isinstance (out .owner .op , Elemwise )
703
714
# and not isinstance(out.owner.op.scalar_op, aes.Composite)
704
715
and len (out .owner .outputs ) == 1
705
- and elemwise_scalar_op_has_c_code (out .owner )
716
+ and elemwise_scalar_op_can_be_fused (out .owner )
706
717
)
707
718
for client , _ in clients :
708
719
if (
@@ -713,7 +724,7 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
713
724
and len (client .outputs ) == 1
714
725
and out .type .broadcastable
715
726
== client .outputs [0 ].type .broadcastable
716
- and elemwise_scalar_op_has_c_code (client )
727
+ and elemwise_scalar_op_can_be_fused (client )
717
728
):
718
729
if client not in fuseable_clients [out ]:
719
730
fuseable_clients [out ].append (client )
@@ -1001,7 +1012,7 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
1001
1012
if (len (inputs ) + len (outputs )) > max_operands :
1002
1013
warn (
1003
1014
"Loop fusion failed because the resulting node would exceed "
1004
- "the kernel argument limit."
1015
+ "the backend limit for number of operands ."
1005
1016
)
1006
1017
break
1007
1018
@@ -1067,30 +1078,68 @@ def print_profile(stream, prof, level=0):
1067
1078
print (blanc , " time_toposort" , prof [7 ], file = stream )
1068
1079
1069
1080
1070
- if config .tensor__local_elemwise_fusion :
1071
- # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
1072
- fuse_seqopt = SequenceDB ()
1073
- fuse_seqopt .register (
1081
+ fuse_opt_py = SequenceDB ()
1082
+ fuse_opt_c = SequenceDB ()
1083
+ fuse_opt_numba = SequenceDB ()
1084
+ for fuse_opt in (fuse_opt_py , fuse_opt_c , fuse_opt_numba ):
1085
+ fuse_opt .register (
1074
1086
"local_add_mul_fusion" ,
1075
1087
EquilibriumGraphRewriter (rewriters = [local_add_mul_fusion ], max_use_ratio = 1000 ),
1076
1088
"fast_run" ,
1077
1089
"fusion" ,
1078
1090
position = 0 ,
1079
1091
)
1080
- fuse_seqopt .register (
1081
- "composite_elemwise_fusion" ,
1082
- FusionOptimizer (),
1092
+ fuse_opt_py .register (
1093
+ "composite_elemwise_fusion_py" ,
1094
+ FusionOptimizer ("py" ),
1095
+ "fast_run" ,
1096
+ "fusion" ,
1097
+ position = 1 ,
1098
+ )
1099
+ fuse_opt_c .register (
1100
+ "composite_elemwise_fusion_c" ,
1101
+ FusionOptimizer ("c" ),
1102
+ "fast_run" ,
1103
+ "fusion" ,
1104
+ position = 1 ,
1105
+ )
1106
+ fuse_opt_numba .register (
1107
+ "composite_elemwise_fusion_numba" ,
1108
+ FusionOptimizer ("numba" ),
1109
+ "fast_run" ,
1110
+ "fusion" ,
1111
+ position = 1 ,
1112
+ )
1113
+
1114
+
1115
+ if config .tensor__local_elemwise_fusion :
1116
+ # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
1117
+ compile .optdb .register ( # type: ignore
1118
+ "elemwise_fusion_c" ,
1119
+ fuse_opt_c ,
1083
1120
"fast_run" ,
1084
1121
"fusion" ,
1085
- position = 1 ,
1122
+ "local_elemwise_fusion" ,
1123
+ "FusionOptimizer" ,
1124
+ "cxx_only" ,
1125
+ position = 49 ,
1086
1126
)
1127
+ # We allow the Python version to run afterwards,
1128
+ # since there is no mode for Python only
1087
1129
compile .optdb .register ( # type: ignore
1088
- "elemwise_fusion " ,
1089
- fuse_seqopt ,
1130
+ "elemwise_fusion_py " ,
1131
+ fuse_opt_py ,
1090
1132
"fast_run" ,
1091
1133
"fusion" ,
1092
1134
"local_elemwise_fusion" ,
1093
1135
"FusionOptimizer" ,
1136
+ position = 49.01 ,
1137
+ )
1138
+ # TODO: Not sure about this... Could rewrites receive info about the linker that is being used?
1139
+ compile .optdb .register ( # type: ignore
1140
+ "elemwise_fusion_numba" ,
1141
+ fuse_opt_numba ,
1142
+ "numba" ,
1094
1143
position = 49 ,
1095
1144
)
1096
1145
0 commit comments