|
35 | 35 | from pytensor.tensor.math import exp
|
36 | 36 | from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
|
37 | 37 | from pytensor.tensor.shape import shape_padleft
|
38 |
| -from pytensor.tensor.var import TensorConstant |
| 38 | +from pytensor.tensor.var import TensorConstant, get_unique_value |
39 | 39 |
|
40 | 40 |
|
41 | 41 | class InplaceElemwiseOptimizer(GraphRewriter):
|
@@ -1203,6 +1203,44 @@ def local_careduce_fusion(fgraph, node):
|
1203 | 1203 | return [new_car_op(*elm_inputs)]
|
1204 | 1204 |
|
1205 | 1205 |
|
@register_specialize
@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
    """Inline scalar constants in Composite graphs.

    Every non-complex outer input for which `get_unique_value` reports a
    single value is removed from the `Elemwise` inputs and substituted, inside
    the `Composite` inner graph, by a scalar constant holding that value.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here; part of the
        node-rewriter signature).
    node
        An `Elemwise` node; only nodes whose ``scalar_op`` is an
        `aes.Composite` are rewritten.

    Returns
    -------
    list of Variable or None
        The outputs of the rebuilt `Elemwise` node, or ``None`` when the
        rewrite does not apply: the scalar op is not a `Composite`, no input
        can be inlined, no outer input would remain, or removing the constants
        would change the broadcast pattern of an output.
    """
    composite_op = node.op.scalar_op

    if not isinstance(composite_op, aes.Composite):
        return None

    new_outer_inputs = []
    new_inner_inputs = []
    inner_replacements = {}
    for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
        # Complex variables don't have a `c_literal` that can be inlined
        if "complex" not in outer_inp.type.dtype:
            # NOTE(review): presumably returns the single repeated value of a
            # constant tensor, or None otherwise -- confirm against
            # `pytensor.tensor.var.get_unique_value`.
            unique_value = get_unique_value(outer_inp)
            if unique_value is not None:
                inner_replacements[inner_inp] = aes.constant(
                    unique_value, dtype=inner_inp.dtype
                )
                continue
        new_outer_inputs.append(outer_inp)
        new_inner_inputs.append(inner_inp)

    if not inner_replacements:
        return None

    # If every input was inlined there is no outer input left to build the
    # new `Elemwise` node from; leave such all-constant nodes untouched.
    if not new_outer_inputs:
        return None

    new_inner_outs = clone_replace(
        composite_op.fgraph.outputs, replace=inner_replacements
    )
    new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
    new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

    # An inlined constant may have been the input that broadcast the output
    # to its shape. Dropping it would then change the output's broadcast
    # pattern, so the new outputs would not be valid replacements.
    if any(
        old_out.type.broadcastable != new_out.type.broadcastable
        for old_out, new_out in zip(node.outputs, new_outputs)
    ):
        return None

    copy_stack_trace(node.outputs, new_outputs)

    return new_outputs
| 1242 | + |
| 1243 | + |
1206 | 1244 | # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
|
1207 | 1245 | fuse_seqopt = SequenceDB()
|
1208 | 1246 | compile.optdb.register(
|
@@ -1243,6 +1281,13 @@ def local_careduce_fusion(fgraph, node):
|
1243 | 1281 | "fusion",
|
1244 | 1282 | position=10,
|
1245 | 1283 | )
|
# Also run the Composite constant-inlining rewrite as part of the fusion
# sequence. It is wrapped with `in2out` and registered at position=20, a
# higher position than the entry registered at position=10 above.
fuse_seqopt.register(
    "local_inline_composite_constants",
    in2out(local_inline_composite_constants),
    "fast_run",
    "fusion",
    position=20,
)
1246 | 1291 |
|
1247 | 1292 |
|
1248 | 1293 | def _rebuild_partial_2f1grad_loop(node, wrt):
|
|
0 commit comments