Skip to content

Commit f0c3d59

Browse files
committed
Implement rewrite to inline Composite constants
1 parent caa7348 commit f0c3d59

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

pytensor/tensor/rewriting/elemwise.py

+51
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,50 @@ def local_careduce_fusion(fgraph, node):
12031203
return [new_car_op(*elm_inputs)]
12041204

12051205

1206+
@register_specialize
@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
    """Inline scalar constants in Composite graphs.

    For an ``Elemwise`` node wrapping a ``Composite`` scalar op, any outer
    input that is a constant with a single unique value is removed from the
    outer/inner input lists and replaced inside the Composite's inner graph
    by a scalar constant. This shrinks the number of runtime inputs the
    fused Composite has to receive.

    Returns
    -------
    list of Variable or None
        Replacement outputs for ``node``, or ``None`` when ``node`` is not a
        Composite Elemwise or no input could be inlined.
    """
    composite_op = node.op.scalar_op

    # Only applies to Elemwise nodes whose scalar op is a Composite.
    if not isinstance(composite_op, aes.Composite):
        return None

    new_outer_inputs = []
    new_inner_inputs = []
    inner_replacements = {}
    # Outer (tensor) inputs and inner (scalar) inputs correspond pairwise.
    for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
        # Complex variables don't have a `c_literal` that can be inlined
        if "complex" not in outer_inp.type.dtype:
            unique_value = get_unique_constant_value(outer_inp)
            if unique_value is not None:
                # Replace the inner input by a scalar constant; the outer
                # input is dropped entirely (no append below).
                inner_replacements[inner_inp] = aes.constant(
                    unique_value, dtype=inner_inp.dtype
                )
                continue
        # Not inlineable: keep both the outer input and its inner counterpart.
        new_outer_inputs.append(outer_inp)
        new_inner_inputs.append(inner_inp)

    # Nothing to inline; leave the node untouched.
    if not inner_replacements:
        return None

    # Rebuild the inner graph with constants substituted for the removed inputs.
    new_inner_outs = clone_replace(
        composite_op.fgraph.outputs, replace=inner_replacements
    )
    new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
    new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

    # Some of the inlined constants were broadcasting the output shape
    if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
        new_outputs = [
            broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
            for new_out in new_outputs
        ]

    copy_stack_trace(node.outputs, new_outputs)
    return new_outputs
1248+
1249+
12061250
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
12071251
fuse_seqopt = SequenceDB()
12081252
compile.optdb.register(
@@ -1243,6 +1287,13 @@ def local_careduce_fusion(fgraph, node):
12431287
"fusion",
12441288
position=10,
12451289
)
1290+
# Run constant-inlining after the fusion pass (registered at position=10),
# so that constants folded into freshly-created Composites get inlined.
fuse_seqopt.register(
    "local_inline_composite_constants",
    in2out(local_inline_composite_constants),
    "fast_run",
    "fusion",
    position=20,
)
12461297

12471298

12481299
def _rebuild_partial_2f1grad_loop(node, wrt):

tests/tensor/rewriting/test_elemwise.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import Mode, get_default_mode
1313
from pytensor.configdefaults import config
1414
from pytensor.gradient import grad
15-
from pytensor.graph.basic import Constant, equal_computations
15+
from pytensor.graph.basic import Constant, ancestors, equal_computations
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs():
14611461
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
14621462

14631463

1464+
@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)])
@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)])
def test_local_inline_composite_constants(op, np_op, const_shape):
    """Constants of various shapes are inlined into the fused Composite.

    Builds ``exp(op(x, const)) + y`` and checks that after the
    ``specialize``/``fusion`` rewrites the constant is no longer an input of
    the single remaining Composite Elemwise node, and that numerical results
    still match the NumPy reference.
    """
    # Uniform-valued constant so it qualifies for inlining regardless of shape.
    const = np.full(shape=const_shape, fill_value=2.5).astype(config.floatX)
    x = vector("x")
    y = vector("y")
    out = at.exp(op(x, const)) + y

    fn = pytensor.function(
        [x, y], out, mode=get_default_mode().including("specialize", "fusion")
    )
    # There should be a single Composite after optimization
    [node] = [
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise)
    ]
    assert isinstance(node.op.scalar_op, Composite)
    assert len(node.inputs) == 2  # x and y, but not const

    # Numerical check against the NumPy reference computation.
    x_test_value = np.arange(5).astype(config.floatX)
    y_test_value = np.ones(5).astype(config.floatX)
    np.testing.assert_allclose(
        fn(x_test_value, y_test_value),
        np.exp(np_op(x_test_value, const)) + y_test_value,
    )
1489+
14641490
def test_local_useless_dimshuffle_makevector():
14651491
a = scalar()
14661492
x = MakeVector(config.floatX)(a)

0 commit comments

Comments
 (0)