Skip to content

Commit b618a0f

Browse files
committed
Implement rewrite to inline Composite constants
1 parent caa7348 commit b618a0f

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

pytensor/tensor/rewriting/elemwise.py

+50
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,49 @@ def local_careduce_fusion(fgraph, node):
12031203
return [new_car_op(*elm_inputs)]
12041204

12051205

1206+
@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
    """Inline scalar constants in Composite graphs.

    Outer inputs of an ``Elemwise(Composite)`` node that hold a single unique
    constant value are dropped from the node's input list, and their inner
    (scalar) counterparts are replaced by scalar constants inside the
    Composite's fgraph. Complex-typed inputs are skipped because they don't
    have a `c_literal` that can be inlined.

    Returns ``None`` when the node is not an ``Elemwise(Composite)`` or when
    no input could be inlined; otherwise returns the replacement outputs.
    """
    composite_op = node.op.scalar_op

    if not isinstance(composite_op, aes.Composite):
        return None

    new_outer_inputs = []
    new_inner_inputs = []
    inner_replacements = {}
    # node.inputs (outer, tensor-level) and composite_op.fgraph.inputs
    # (inner, scalar-level) are walked in lockstep — they are positionally
    # aligned by construction of the Composite.
    for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
        # Complex variables don't have a `c_literal` that can be inlined
        if "complex" not in outer_inp.type.dtype:
            unique_value = get_unique_constant_value(outer_inp)
            if unique_value is not None:
                # Constant input: remember a scalar-constant replacement for
                # the inner variable and drop the outer input entirely.
                inner_replacements[inner_inp] = aes.constant(
                    unique_value, dtype=inner_inp.dtype
                )
                continue
        # Non-constant (or complex) input: keep both outer and inner variables.
        new_outer_inputs.append(outer_inp)
        new_inner_inputs.append(inner_inp)

    # Nothing was inlined — leave the node untouched.
    if not inner_replacements:
        return None

    # Rebuild the inner graph with the constants substituted in, then wrap it
    # in a fresh Composite/Elemwise over the reduced input list.
    new_inner_outs = clone_replace(
        composite_op.fgraph.outputs, replace=inner_replacements
    )
    new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
    new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

    # Some of the inlined constants were broadcasting the output shape
    if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
        new_outputs = [
            broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
            for new_out in new_outputs
        ]

    copy_stack_trace(node.outputs, new_outputs)
    return new_outputs
12061249
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
12071250
fuse_seqopt = SequenceDB()
12081251
compile.optdb.register(
@@ -1243,6 +1286,13 @@ def local_careduce_fusion(fgraph, node):
12431286
"fusion",
12441287
position=10,
12451288
)
1289+
# Registered in the fusion sequence after the fusion pass itself
# (position=10), so constants are inlined into already-fused Composites.
fuse_seqopt.register(
    "local_inline_composite_constants",
    in2out(local_inline_composite_constants),
    "fast_run",
    "fusion",
    position=20,
)
12461296

12471297

12481298
def _rebuild_partial_2f1grad_loop(node, wrt):

tests/tensor/rewriting/test_elemwise.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import Mode, get_default_mode
1313
from pytensor.configdefaults import config
1414
from pytensor.gradient import grad
15-
from pytensor.graph.basic import Constant, equal_computations
15+
from pytensor.graph.basic import Constant, ancestors, equal_computations
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs():
14611461
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
14621462

14631463

1464+
@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)])
@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)])
def test_local_inline_composite_constants(op, np_op, const_shape):
    """Constants of various shapes get inlined into the fused Composite."""
    constant_value = np.full(shape=const_shape, fill_value=2.5).astype(config.floatX)
    x = vector("x")
    y = vector("y")
    graph_out = at.exp(op(x, constant_value)) + y

    compiled_fn = pytensor.function(
        [x, y], graph_out, mode=get_default_mode().including("specialize", "fusion")
    )

    # After optimization the graph must contain exactly one Elemwise node,
    # and it must wrap a Composite whose constant input was inlined away.
    elemwise_nodes = [
        apply_node
        for apply_node in compiled_fn.maker.fgraph.apply_nodes
        if isinstance(apply_node.op, Elemwise)
    ]
    (composite_node,) = elemwise_nodes
    assert isinstance(composite_node.op.scalar_op, Composite)
    assert len(composite_node.inputs) == 2  # x and y, but not const

    # Numerical result must match the plain NumPy computation.
    x_val = np.arange(5).astype(config.floatX)
    y_val = np.ones(5).astype(config.floatX)
    expected = np.exp(np_op(x_val, constant_value)) + y_val
    np.testing.assert_allclose(compiled_fn(x_val, y_val), expected)
14641490
def test_local_useless_dimshuffle_makevector():
14651491
a = scalar()
14661492
x = MakeVector(config.floatX)(a)

0 commit comments

Comments
 (0)