Skip to content

Commit 71c018b

Browse files
committed
Add rewrite to remove broadcasted constants from Elemwise graphs
1 parent b618a0f commit 71c018b

File tree

3 files changed

+85
-4
lines changed

3 files changed

+85
-4
lines changed

pytensor/tensor/rewriting/elemwise.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar
55
from warnings import warn
66

7+
import numpy as np
8+
79
import pytensor
810
import pytensor.scalar.basic as aes
911
from pytensor import clone_replace, compile
@@ -28,14 +30,19 @@
2830
MakeVector,
2931
alloc,
3032
cast,
33+
constant,
3134
get_underlying_scalar_constant_value,
3235
)
3336
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
3437
from pytensor.tensor.exceptions import NotScalarConstantError
3538
from pytensor.tensor.math import exp
36-
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
39+
from pytensor.tensor.rewriting.basic import (
40+
broadcast_like,
41+
register_canonicalize,
42+
register_specialize,
43+
)
3744
from pytensor.tensor.shape import shape_padleft
38-
from pytensor.tensor.var import TensorConstant
45+
from pytensor.tensor.var import TensorConstant, get_unique_constant_value
3946

4047

4148
class InplaceElemwiseOptimizer(GraphRewriter):
@@ -1295,6 +1302,65 @@ def local_inline_composite_constants(fgraph, node):
12951302
)
12961303

12971304

1305+
@node_rewriter([Elemwise])
def local_replace_broadcasted_constants(fgraph, node):
    """Remove broadcasted constants from Elemwise graphs.

    Elemwise(matrix, ones((3, 4))) -> Elemwise(vector, ones((1, 1)))

    In cases where the constant influenced the final shape of the Elemwise
    operation, we broadcast (via alloc) the new Elemwise result:

    Elemwise(row, ones((3, 4))) -> Alloc(Elemwise(row, ones((1, 1))), 3, 4)

    This will avoid useless iterations over constant arrays.

    Returns ``None`` (no rewrite) when the node has a single input or when no
    replaceable constant input is found.
    """
    # A single-input Elemwise cannot have a redundant broadcasted constant:
    # there is no other input for the constant to be broadcasted against.
    if len(node.inputs) == 1:
        return None

    new_elem_inps = []
    ndims = node.outputs[0].type.ndim
    found_const = False
    for inp in node.inputs:
        # Only inputs with at least one non-broadcastable dim are candidates;
        # fully-broadcastable inputs are already minimal.
        if not all(b for b in inp.type.broadcastable):
            # get_unique_constant_value returns the scalar value if the input
            # is a constant whose entries are all identical, else None.
            constant_value = get_unique_constant_value(inp)
            if constant_value is not None:
                # Rebuild the constant with all-size-1 (broadcastable) dims so
                # the Elemwise no longer iterates over the full constant array.
                # astype preserves the original dtype so the Elemwise output
                # dtype is unchanged.
                constant_value = np.expand_dims(
                    constant_value, axis=tuple(range(ndims))
                ).astype(inp.type.dtype)
                new_elem_inps.append(constant(constant_value))
                found_const = True
                continue

        new_elem_inps.append(inp)

    if not found_const:
        return None

    new_outs = node.op.make_node(*new_elem_inps).outputs

    # The constants were needed to enforce the output shape: if shrinking them
    # changed the output broadcast pattern, restore it with an explicit Alloc
    # shaped like the original output.
    if node.outputs[0].type.broadcastable != new_outs[0].type.broadcastable:
        new_outs = [
            broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
            for new_out in new_outs
        ]

    copy_stack_trace(node.outputs, new_outs)
    return new_outs
1352+
1353+
1354+
# We register this immediately after the fusion database.
# We don't want Allocs to break up the fusion rewrites
# (position 49.01 places this right after the fusion pass in optdb;
# NOTE(review): assumes the fusion database sits at position 49 — confirm).
compile.optdb.register(
    "local_replace_broadcasted_constants",
    in2out(local_replace_broadcasted_constants),
    "fast_run",
    position=49.01,
)
1362+
1363+
12981364
def _rebuild_partial_2f1grad_loop(node, wrt):
12991365
a, b, c, log_z, sign_z = node.inputs[-5:]
13001366
z = exp(log_z) * sign_z

tests/tensor/rewriting/test_elemwise.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import Mode, get_default_mode
1313
from pytensor.configdefaults import config
1414
from pytensor.gradient import grad
15-
from pytensor.graph.basic import Constant, ancestors, equal_computations
15+
from pytensor.graph.basic import Constant, equal_computations
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -178,6 +178,21 @@ def test_dimshuffle_lift_multi_out_elemwise(self):
178178
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
179179

180180

181+
def test_local_replace_broadcasted_constants():
    """A full (2, 5) constant is shrunk to (1, 1) and the output re-Alloc'd."""
    const = np.full(shape=(2, 5), fill_value=2.6)
    x = scalar("x")
    out = at.power(x, const)
    # ShapeOpt is included so shape information is available for the Alloc
    # that restores the original (2, 5) output shape.
    new_out = rewrite_graph(
        out, include=["ShapeOpt", "local_replace_broadcasted_constants"]
    )
    # Expected graph: power against the (1, 1) constant, broadcast back to
    # the original shape via alloc with constant int64 dims.
    ref_out = at.alloc(
        at.power(x, [[2.6]]),
        at.constant(2, dtype="int64"),
        at.constant(5, dtype="int64"),
    )
    assert equal_computations([new_out], [ref_out])
194+
195+
181196
def test_local_useless_dimshuffle_in_reshape():
182197
vec = TensorType(dtype="float64", shape=(None,))("vector")
183198
mat = TensorType(dtype="float64", shape=(None, None))("mat")

tests/tensor/rewriting/test_math.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,7 @@ def test_specified_shape_by_constant(self):
987987
new_out = rewrite_graph(
988988
out, custom_rewrite=in2out(local_mul_canonizer, name="test")
989989
)
990-
expected_out = [2.0] * specify_shape(x, (5,))
990+
expected_out = np.array([2.0], dtype=config.floatX) * specify_shape(x, (5,))
991991
assert equal_computations([new_out], [expected_out])
992992

993993
def test_broadcasted_by_constant(self):

0 commit comments

Comments
 (0)