Skip to content

Commit ce2d4c3

Browse files
committed
Implement rewrite to inline Composite constants
1 parent 8933712 commit ce2d4c3

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

pytensor/tensor/rewriting/elemwise.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from pytensor.tensor.math import exp
3636
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
3737
from pytensor.tensor.shape import shape_padleft
38-
from pytensor.tensor.var import TensorConstant
38+
from pytensor.tensor.var import TensorConstant, get_unique_value
3939

4040

4141
class InplaceElemwiseOptimizer(GraphRewriter):
@@ -1203,6 +1203,44 @@ def local_careduce_fusion(fgraph, node):
12031203
return [new_car_op(*elm_inputs)]
12041204

12051205

1206+
@register_specialize
1207+
@node_rewriter([Elemwise])
1208+
def local_inline_composite_constants(fgraph, node):
1209+
"""Inline scalar constants in Composite graphs."""
1210+
composite_op = node.op.scalar_op
1211+
1212+
if not isinstance(composite_op, aes.Composite):
1213+
return None
1214+
1215+
new_outer_inputs = []
1216+
new_inner_inputs = []
1217+
inner_replacements = {}
1218+
for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
1219+
# Complex variables don't have a `c_literal` that can be inlined
1220+
if "complex" not in outer_inp.type.dtype:
1221+
unique_value = get_unique_value(outer_inp)
1222+
if unique_value is not None:
1223+
inner_replacements[inner_inp] = aes.constant(
1224+
unique_value, dtype=inner_inp.dtype
1225+
)
1226+
continue
1227+
new_outer_inputs.append(outer_inp)
1228+
new_inner_inputs.append(inner_inp)
1229+
1230+
if not inner_replacements:
1231+
return None
1232+
1233+
new_inner_outs = clone_replace(
1234+
composite_op.fgraph.outputs, replace=inner_replacements
1235+
)
1236+
new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
1237+
new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs
1238+
1239+
copy_stack_trace(node.outputs, new_outputs)
1240+
1241+
return new_outputs
1242+
1243+
12061244
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
12071245
fuse_seqopt = SequenceDB()
12081246
compile.optdb.register(
@@ -1243,6 +1281,13 @@ def local_careduce_fusion(fgraph, node):
12431281
"fusion",
12441282
position=10,
12451283
)
1284+
fuse_seqopt.register(
1285+
"local_inline_composite_constants",
1286+
in2out(local_inline_composite_constants),
1287+
"fast_run",
1288+
"fusion",
1289+
position=20,
1290+
)
12461291

12471292

12481293
def _rebuild_partial_2f1grad_loop(node, wrt):

tests/tensor/rewriting/test_elemwise.py

+14
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,20 @@ def test_local_useless_composite_outputs():
14611461
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
14621462

14631463

1464+
def test_local_inline_composite_constants():
1465+
x = vector("x")
1466+
out = exp(x + 2)
1467+
1468+
fn = pytensor.function([x], out, mode=get_default_mode().including("fusion"))
1469+
1470+
[node] = fn.maker.fgraph.apply_nodes
1471+
assert isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Composite)
1472+
assert node.inputs == [x]
1473+
1474+
test_value = np.arange(5).astype(config.floatX)
1475+
np.testing.assert_allclose(fn(test_value), np.exp(test_value + 2))
1476+
1477+
14641478
def test_local_useless_dimshuffle_makevector():
14651479
a = scalar()
14661480
x = MakeVector(config.floatX)(a)

0 commit comments

Comments
 (0)