|
33 | 33 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
34 | 34 | from pytensor.tensor.exceptions import NotScalarConstantError
|
35 | 35 | from pytensor.tensor.math import exp
|
36 |
| -from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize |
| 36 | +from pytensor.tensor.rewriting.basic import ( |
| 37 | + broadcast_like, |
| 38 | + register_canonicalize, |
| 39 | + register_specialize, |
| 40 | +) |
37 | 41 | from pytensor.tensor.shape import shape_padleft
|
38 |
| -from pytensor.tensor.var import TensorConstant |
| 42 | +from pytensor.tensor.var import TensorConstant, get_unique_constant_value |
39 | 43 |
|
40 | 44 |
|
41 | 45 | class InplaceElemwiseOptimizer(GraphRewriter):
|
@@ -1203,6 +1207,49 @@ def local_careduce_fusion(fgraph, node):
|
1203 | 1207 | return [new_car_op(*elm_inputs)]
|
1204 | 1208 |
|
1205 | 1209 |
|
@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
    """Inline scalar constants in Composite graphs.

    Applies to `Elemwise` nodes whose scalar op is an `aes.Composite`. Each
    outer input for which `get_unique_constant_value` returns a value is
    replaced, inside the composite's inner graph, by an `aes.constant` of the
    matching inner dtype, and is dropped from the inputs of the rebuilt node.
    Inputs with a complex dtype are never inlined.

    Returns
    -------
    None
        If the node's scalar op is not a `Composite`, or if no input could
        be inlined.
    list
        The outputs of the rebuilt `Elemwise` node, passed through
        `broadcast_like` when the broadcastable pattern of the first output
        differs from the original one.
    """
    scalar_op = node.op.scalar_op
    if not isinstance(scalar_op, aes.Composite):
        return None

    kept_outer = []
    kept_inner = []
    inlined = {}
    for outer, inner in zip(node.inputs, scalar_op.fgraph.inputs):
        # Complex variables don't have a `c_literal` that can be inlined,
        # so they are never even queried for a unique constant value.
        value = (
            None
            if "complex" in outer.type.dtype
            else get_unique_constant_value(outer)
        )
        if value is None:
            kept_outer.append(outer)
            kept_inner.append(inner)
        else:
            inlined[inner] = aes.constant(value, dtype=inner.dtype)

    if not inlined:
        return None

    # Rebuild the inner graph with the constants substituted in, then wrap it
    # in a new Composite that only takes the remaining (non-inlined) inputs.
    inner_outs = clone_replace(scalar_op.fgraph.outputs, replace=inlined)
    rebuilt_op = Elemwise(aes.Composite(kept_inner, inner_outs))
    new_outputs = rebuilt_op.make_node(*kept_outer).outputs

    # Some of the inlined constants were broadcasting the output shape
    old_out = node.outputs[0]
    if old_out.type.broadcastable != new_outputs[0].type.broadcastable:
        new_outputs = [
            broadcast_like(out, template=old_out, fgraph=fgraph)
            for out in new_outputs
        ]

    copy_stack_trace(node.outputs, new_outputs)
    return new_outputs
| 1251 | + |
| 1252 | + |
1206 | 1253 | # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
|
1207 | 1254 | fuse_seqopt = SequenceDB()
|
1208 | 1255 | compile.optdb.register(
|
@@ -1243,6 +1290,13 @@ def local_careduce_fusion(fgraph, node):
|
1243 | 1290 | "fusion",
|
1244 | 1291 | position=10,
|
1245 | 1292 | )
|
# Add the Composite constant-inlining rewrite to the fusion sequence at
# position 20, under the "fast_run" and "fusion" tags.
fuse_seqopt.register(
    "local_inline_composite_constants",
    in2out(local_inline_composite_constants),
    "fast_run",
    "fusion",
    position=20,
)
1246 | 1300 |
|
1247 | 1301 |
|
1248 | 1302 | def _rebuild_partial_2f1grad_loop(node, wrt):
|
|
0 commit comments