Skip to content

Commit b8be9e8

Browse files
committed
Implement rewrite to inline Composite constants
1 parent 8289de3 commit b8be9e8

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

pytensor/tensor/rewriting/elemwise.py

+50
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,49 @@ def local_careduce_fusion(fgraph, node):
12541254
return [new_car_op(*elm_inputs)]
12551255

12561256

1257+
@register_specialize
1258+
@node_rewriter([Elemwise])
1259+
def local_inline_composite_constants(fgraph, node):
1260+
"""Inline scalar constants in Composite graphs."""
1261+
composite_op = node.op.scalar_op
1262+
1263+
if not isinstance(composite_op, aes.Composite):
1264+
return None
1265+
1266+
new_outer_inputs = []
1267+
new_inner_inputs = []
1268+
inner_replacements = {}
1269+
for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
1270+
# Complex variables don't have a `c_literal` that can be inlined
1271+
if "complex" not in outer_inp.type.dtype:
1272+
unique_value = get_unique_constant_value(outer_inp)
1273+
if unique_value is not None:
1274+
inner_replacements[inner_inp] = aes.constant(
1275+
unique_value, dtype=inner_inp.dtype
1276+
)
1277+
continue
1278+
new_outer_inputs.append(outer_inp)
1279+
new_inner_inputs.append(inner_inp)
1280+
1281+
if not inner_replacements:
1282+
return None
1283+
1284+
new_inner_outs = clone_replace(
1285+
composite_op.fgraph.outputs, replace=inner_replacements
1286+
)
1287+
new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
1288+
new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs
1289+
1290+
# Some of the inlined constants were broadcasting the output shape
1291+
# This should be taken care of by `local_replace_broadcasted_constant`
1292+
for old_out, new_out in zip(node.outputs, new_outputs):
1293+
if old_out.type.broadcastable != new_out.type.broadcastable:
1294+
return None
1295+
1296+
copy_stack_trace(node.outputs, new_outputs)
1297+
return new_outputs
1298+
1299+
12571300
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
12581301
fuse_seqopt = SequenceDB()
12591302
compile.optdb.register(
@@ -1294,6 +1337,13 @@ def local_careduce_fusion(fgraph, node):
12941337
"fusion",
12951338
position=10,
12961339
)
1340+
fuse_seqopt.register(
1341+
"local_inline_composite_constants",
1342+
in2out(local_inline_composite_constants),
1343+
"fast_run",
1344+
"fusion",
1345+
position=20,
1346+
)
12971347

12981348

12991349
def _rebuild_partial_2f1grad_loop(node, wrt):

tests/tensor/rewriting/test_elemwise.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import Mode, get_default_mode
1313
from pytensor.configdefaults import config
1414
from pytensor.gradient import grad
15-
from pytensor.graph.basic import Constant, equal_computations
15+
from pytensor.graph.basic import Constant, ancestors, equal_computations
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -1474,6 +1474,27 @@ def test_local_useless_composite_outputs():
14741474
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
14751475

14761476

1477+
@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)])
1478+
@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)])
1479+
def test_local_inline_composite_constants(op, np_op, const_shape):
1480+
x = vector("x")
1481+
const = np.full(shape=const_shape, fill_value=2.5)
1482+
out = exp(op(x, const))
1483+
1484+
fn = pytensor.function(
1485+
[x], out, mode=get_default_mode().including("specialize", "fusion")
1486+
)
1487+
[node] = [
1488+
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise)
1489+
]
1490+
assert isinstance(node.op.scalar_op, Composite)
1491+
assert len(node.inputs) == 1
1492+
assert x in ancestors(node.inputs)
1493+
1494+
test_value = np.arange(5).astype(config.floatX)
1495+
np.testing.assert_allclose(fn(test_value), np.exp(np_op(test_value, const)))
1496+
1497+
14771498
def test_local_useless_dimshuffle_makevector():
14781499
a = scalar()
14791500
x = MakeVector(config.floatX)(a)

0 commit comments

Comments
 (0)