|
31 | 31 | alloc,
|
32 | 32 | cast,
|
33 | 33 | constant,
|
34 |
| - get_underlying_scalar_constant_value, |
| 34 | + get_underlying_scalar_constant_value, constant, |
35 | 35 | )
|
36 | 36 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
37 | 37 | from pytensor.tensor.exceptions import NotScalarConstantError
|
@@ -1254,6 +1254,49 @@ def local_careduce_fusion(fgraph, node):
|
1254 | 1254 | return [new_car_op(*elm_inputs)]
|
1255 | 1255 |
|
1256 | 1256 |
|
@register_specialize
@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
    """Inline scalar constants in Composite graphs.

    For an `Elemwise` node whose scalar op is an `aes.Composite`, every outer
    input for which `get_unique_constant_value` returns a value is removed
    from the node's inputs and replaced, inside the Composite's inner graph,
    by an `aes.constant` of that value (cast to the inner input's dtype).

    Parameters
    ----------
    fgraph
        The function graph being rewritten (not used directly here).
    node
        The `Elemwise` node under consideration.

    Returns
    -------
    list of variables or None
        The outputs of the rebuilt `Elemwise` node, or ``None`` when the
        rewrite does not apply: the scalar op is not a Composite, no input
        could be inlined, every input would be inlined, or inlining would
        change the broadcastable pattern of an output.
    """
    composite_op = node.op.scalar_op

    if not isinstance(composite_op, aes.Composite):
        return None

    new_outer_inputs = []
    new_inner_inputs = []
    inner_replacements = {}
    for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
        # Complex variables don't have a `c_literal` that can be inlined
        if "complex" not in outer_inp.type.dtype:
            unique_value = get_unique_constant_value(outer_inp)
            if unique_value is not None:
                inner_replacements[inner_inp] = aes.constant(
                    unique_value, dtype=inner_inp.dtype
                )
                # This input is now baked into the inner graph; drop it from
                # both the outer and the inner input lists.
                continue
        new_outer_inputs.append(outer_inp)
        new_inner_inputs.append(inner_inp)

    if not inner_replacements:
        return None

    # If every input was inlined there is nothing left to feed the rebuilt
    # node. NOTE(review): `Elemwise.make_node` is expected to fail when called
    # with zero inputs (it derives the output type from its inputs), so we
    # decline the rewrite instead of raising from inside the rewriter.
    if not new_outer_inputs:
        return None

    new_inner_outs = clone_replace(
        composite_op.fgraph.outputs, replace=inner_replacements
    )
    new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
    new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

    # Some of the inlined constants were broadcasting the output shape
    # This should be taken care of by `local_replace_broadcasted_constant`
    if any(
        old_out.type.broadcastable != new_out.type.broadcastable
        for old_out, new_out in zip(node.outputs, new_outputs)
    ):
        return None

    copy_stack_trace(node.outputs, new_outputs)
    return new_outputs
| 1298 | + |
| 1299 | + |
1257 | 1300 | # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
|
1258 | 1301 | fuse_seqopt = SequenceDB()
|
1259 | 1302 | compile.optdb.register(
|
@@ -1294,6 +1337,13 @@ def local_careduce_fusion(fgraph, node):
|
1294 | 1337 | "fusion",
|
1295 | 1338 | position=10,
|
1296 | 1339 | )
|
# Add the Composite constant-inlining rewrite to the fusion sequence.
# It is registered under the "fast_run" and "fusion" tags at position 20.
fuse_seqopt.register(
    "local_inline_composite_constants",
    in2out(local_inline_composite_constants),
    "fast_run",
    "fusion",
    position=20,
)
1297 | 1347 |
|
1298 | 1348 |
|
1299 | 1349 | def _rebuild_partial_2f1grad_loop(node, wrt):
|
|
0 commit comments