Skip to content

Commit 451b18c

Browse files
committed
Implement rewrite to inline Composite constants
1 parent eb2431d commit 451b18c

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

pytensor/tensor/rewriting/elemwise.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,13 @@
3333
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
3434
from pytensor.tensor.exceptions import NotScalarConstantError
3535
from pytensor.tensor.math import exp
36-
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
36+
from pytensor.tensor.rewriting.basic import (
37+
broadcast_like,
38+
register_canonicalize,
39+
register_specialize,
40+
)
3741
from pytensor.tensor.shape import shape_padleft
38-
from pytensor.tensor.var import TensorConstant
42+
from pytensor.tensor.var import TensorConstant, get_unique_constant_value
3943

4044

4145
class InplaceElemwiseOptimizer(GraphRewriter):
@@ -1203,6 +1207,49 @@ def local_careduce_fusion(fgraph, node):
12031207
return [new_car_op(*elm_inputs)]
12041208

12051209

1210+
@node_rewriter([Elemwise])
1211+
def local_inline_composite_constants(fgraph, node):
1212+
"""Inline scalar constants in Composite graphs."""
1213+
composite_op = node.op.scalar_op
1214+
1215+
if not isinstance(composite_op, aes.Composite):
1216+
return None
1217+
1218+
new_outer_inputs = []
1219+
new_inner_inputs = []
1220+
inner_replacements = {}
1221+
for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
1222+
# Complex variables don't have a `c_literal` that can be inlined
1223+
if "complex" not in outer_inp.type.dtype:
1224+
unique_value = get_unique_constant_value(outer_inp)
1225+
if unique_value is not None:
1226+
inner_replacements[inner_inp] = aes.constant(
1227+
unique_value, dtype=inner_inp.dtype
1228+
)
1229+
continue
1230+
new_outer_inputs.append(outer_inp)
1231+
new_inner_inputs.append(inner_inp)
1232+
1233+
if not inner_replacements:
1234+
return None
1235+
1236+
new_inner_outs = clone_replace(
1237+
composite_op.fgraph.outputs, replace=inner_replacements
1238+
)
1239+
new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
1240+
new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs
1241+
1242+
# Some of the inlined constants were broadcasting the output shape
1243+
if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
1244+
new_outputs = [
1245+
broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
1246+
for new_out in new_outputs
1247+
]
1248+
1249+
copy_stack_trace(node.outputs, new_outputs)
1250+
return new_outputs
1251+
1252+
12061253
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
12071254
fuse_seqopt = SequenceDB()
12081255
compile.optdb.register(
@@ -1243,6 +1290,13 @@ def local_careduce_fusion(fgraph, node):
12431290
"fusion",
12441291
position=10,
12451292
)
1293+
fuse_seqopt.register(
1294+
"local_inline_composite_constants",
1295+
in2out(local_inline_composite_constants),
1296+
"fast_run",
1297+
"fusion",
1298+
position=20,
1299+
)
12461300

12471301

12481302
def _rebuild_partial_2f1grad_loop(node, wrt):

tests/tensor/rewriting/test_elemwise.py

+26
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs():
14611461
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
14621462

14631463

1464+
@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)])
1465+
@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)])
1466+
def test_local_inline_composite_constants(op, np_op, const_shape):
1467+
const = np.full(shape=const_shape, fill_value=2.5).astype(config.floatX)
1468+
x = vector("x")
1469+
y = vector("y")
1470+
out = at.exp(op(x, const)) + y
1471+
1472+
fn = pytensor.function(
1473+
[x, y], out, mode=get_default_mode().including("specialize", "fusion")
1474+
)
1475+
# There should be a single Composite after optimization
1476+
[node] = [
1477+
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise)
1478+
]
1479+
assert isinstance(node.op.scalar_op, Composite)
1480+
assert len(node.inputs) == 2 # x and y, but not const
1481+
1482+
x_test_value = np.arange(5).astype(config.floatX)
1483+
y_test_value = np.ones(5).astype(config.floatX)
1484+
np.testing.assert_allclose(
1485+
fn(x_test_value, y_test_value),
1486+
np.exp(np_op(x_test_value, const)) + y_test_value,
1487+
)
1488+
1489+
14641490
def test_local_useless_dimshuffle_makevector():
14651491
a = scalar()
14661492
x = MakeVector(config.floatX)(a)

0 commit comments

Comments
 (0)