@@ -1254,6 +1254,49 @@ def local_careduce_fusion(fgraph, node):
1254
1254
return [new_car_op (* elm_inputs )]
1255
1255
1256
1256
1257
+ @register_specialize
1258
+ @node_rewriter ([Elemwise ])
1259
+ def local_inline_composite_constants (fgraph , node ):
1260
+ """Inline scalar constants in Composite graphs."""
1261
+ composite_op = node .op .scalar_op
1262
+
1263
+ if not isinstance (composite_op , aes .Composite ):
1264
+ return None
1265
+
1266
+ new_outer_inputs = []
1267
+ new_inner_inputs = []
1268
+ inner_replacements = {}
1269
+ for outer_inp , inner_inp in zip (node .inputs , composite_op .fgraph .inputs ):
1270
+ # Complex variables don't have a `c_literal` that can be inlined
1271
+ if "complex" not in outer_inp .type .dtype :
1272
+ unique_value = get_unique_constant_value (outer_inp )
1273
+ if unique_value is not None :
1274
+ inner_replacements [inner_inp ] = aes .constant (
1275
+ unique_value , dtype = inner_inp .dtype
1276
+ )
1277
+ continue
1278
+ new_outer_inputs .append (outer_inp )
1279
+ new_inner_inputs .append (inner_inp )
1280
+
1281
+ if not inner_replacements :
1282
+ return None
1283
+
1284
+ new_inner_outs = clone_replace (
1285
+ composite_op .fgraph .outputs , replace = inner_replacements
1286
+ )
1287
+ new_composite_op = aes .Composite (new_inner_inputs , new_inner_outs )
1288
+ new_outputs = Elemwise (new_composite_op ).make_node (* new_outer_inputs ).outputs
1289
+
1290
+ # Some of the inlined constants were broadcasting the output shape
1291
+ # This should be taken care of by `local_replace_broadcasted_constant`
1292
+ for old_out , new_out in zip (node .outputs , new_outputs ):
1293
+ if old_out .type .broadcastable != new_out .type .broadcastable :
1294
+ return None
1295
+
1296
+ copy_stack_trace (node .outputs , new_outputs )
1297
+ return new_outputs
1298
+
1299
+
1257
1300
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
1258
1301
fuse_seqopt = SequenceDB ()
1259
1302
compile .optdb .register (
@@ -1294,6 +1337,13 @@ def local_careduce_fusion(fgraph, node):
1294
1337
"fusion" ,
1295
1338
position = 10 ,
1296
1339
)
1340
+ fuse_seqopt .register (
1341
+ "local_inline_composite_constants" ,
1342
+ in2out (local_inline_composite_constants ),
1343
+ "fast_run" ,
1344
+ "fusion" ,
1345
+ position = 20 ,
1346
+ )
1297
1347
1298
1348
1299
1349
def _rebuild_partial_2f1grad_loop (node , wrt ):
0 commit comments