@@ -1150,11 +1150,20 @@ def local_careduce_fusion(fgraph, node):
1150
1150
"""Fuse a `CAReduce` applied to an `Elemwise`."""
1151
1151
1152
1152
(car_input ,) = node .inputs
1153
+ car_scalar_op = node .op .scalar_op
1154
+
1155
+ # FIXME: This check is needed because of the faulty logic in the FIXME below!
1156
+ # Right now, rewrite only works for `Sum`/`Prod`
1157
+ if not isinstance (car_scalar_op , (aes .Add , aes .Mul )):
1158
+ return None
1159
+
1153
1160
elm_node = car_input .owner
1154
1161
1155
1162
if elm_node is None or not isinstance (elm_node .op , Elemwise ):
1156
1163
return False
1157
1164
1165
+ elm_scalar_op = elm_node .op .scalar_op
1166
+
1158
1167
elm_inputs = elm_node .inputs
1159
1168
elm_outputs = elm_node .outputs
1160
1169
@@ -1166,21 +1175,15 @@ def local_careduce_fusion(fgraph, node):
1166
1175
return False
1167
1176
1168
1177
# Don't form the fusion when the target language is Python
1169
- elm_scalar_op = elm_node .op .scalar_op
1170
- car_scalar_op = node .op .scalar_op
1171
-
1172
1178
if get_target_language () == ("py" ,):
1173
1179
return False
1174
1180
1175
- try :
1176
- elm_scalar_op .c_code (
1177
- elm_node ,
1178
- "test_presence_of_c_code" ,
1179
- ["x" for x in elm_inputs ],
1180
- ["z" for z in elm_outputs ],
1181
- {"fail" : "%(fail)s" },
1182
- )
1181
+ if not elm_scalar_op .supports_c_code (elm_inputs , elm_outputs ):
1182
+ return None
1183
1183
1184
+ # FIXME: This fails with Ops like `Max` whose `c_code` always expects two inputs!
1185
+ # Should implement a `CAReduce.supports_c_code`?
1186
+ try :
1184
1187
car_scalar_op .c_code (
1185
1188
node ,
1186
1189
"test_presence_of_c_code" ,
@@ -1191,18 +1194,24 @@ def local_careduce_fusion(fgraph, node):
1191
1194
except (NotImplementedError , MethodNotDefined ):
1192
1195
return False
1193
1196
1194
- car_axis = node .op .axis
1197
+ car_op = node .op
1198
+ car_acc_dtype = node .op .acc_dtype
1195
1199
1196
1200
scalar_elm_inputs = [
1197
1201
aes .get_scalar_type (inp .type .dtype ).make_variable () for inp in elm_inputs
1198
1202
]
1203
+
1199
1204
elm_output = elm_scalar_op (* scalar_elm_inputs )
1205
+
1200
1206
# This input represents the previous value in the `CAReduce` binary reduction
1201
- carried_car_input = elm_output .type ()
1202
- scalar_fused_outputs = [car_scalar_op (carried_car_input , elm_output )]
1207
+ carried_car_input = aes .get_scalar_type (car_acc_dtype ).make_variable ()
1208
+
1209
+ scalar_fused_output = car_scalar_op (carried_car_input , elm_output )
1210
+ if scalar_fused_output .type .dtype != car_acc_dtype :
1211
+ scalar_fused_output = aes .cast (scalar_fused_output , car_acc_dtype )
1203
1212
1204
1213
fused_scalar_op = aes .Composite (
1205
- inputs = [carried_car_input ] + scalar_elm_inputs , outputs = scalar_fused_outputs
1214
+ inputs = [carried_car_input ] + scalar_elm_inputs , outputs = [ scalar_fused_output ]
1206
1215
)
1207
1216
1208
1217
# The fused `Op` needs to look and behave like a `BinaryScalarOp`
@@ -1211,7 +1220,13 @@ def local_careduce_fusion(fgraph, node):
1211
1220
fused_scalar_op .nin = 2
1212
1221
fused_scalar_op .nout = 1
1213
1222
1214
- new_car_op = CAReduce (fused_scalar_op , car_axis )
1223
+ new_car_op = CAReduce (
1224
+ scalar_op = fused_scalar_op ,
1225
+ axis = car_op .axis ,
1226
+ acc_dtype = car_acc_dtype ,
1227
+ dtype = car_op .dtype ,
1228
+ upcast_discrete_output = car_op .upcast_discrete_output ,
1229
+ )
1215
1230
1216
1231
return [new_car_op (* elm_inputs )]
1217
1232
0 commit comments