Skip to content

Commit be115bb

Browse files
committed
Add rewrite to remove broadcasted constants from Elemwise graphs
1 parent f0c3d59 commit be115bb

File tree

2 files changed

+84
-3
lines changed

2 files changed

+84
-3
lines changed

pytensor/tensor/rewriting/elemwise.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar
55
from warnings import warn
66

7+
import numpy as np
8+
79
import pytensor
810
import pytensor.scalar.basic as aes
911
from pytensor import clone_replace, compile
@@ -28,14 +30,19 @@
2830
MakeVector,
2931
alloc,
3032
cast,
33+
constant,
3134
get_underlying_scalar_constant_value,
3235
)
3336
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
3437
from pytensor.tensor.exceptions import NotScalarConstantError
3538
from pytensor.tensor.math import exp
36-
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
39+
from pytensor.tensor.rewriting.basic import (
40+
broadcast_like,
41+
register_canonicalize,
42+
register_specialize,
43+
)
3744
from pytensor.tensor.shape import shape_padleft
38-
from pytensor.tensor.var import TensorConstant
45+
from pytensor.tensor.var import TensorConstant, get_unique_constant_value
3946

4047

4148
class InplaceElemwiseOptimizer(GraphRewriter):
@@ -1296,6 +1303,65 @@ def local_inline_composite_constants(fgraph, node):
12961303
)
12971304

12981305

1306+
@node_rewriter([Elemwise])
1307+
def local_replace_broadcasted_constants(fgraph, node):
1308+
"""Remove broadcasted constants from Elemwise graphs
1309+
1310+
Elemwise(matrix, ones((3, 4))) -> Elemwise(vector, ones((1, 1)))
1311+
1312+
In cases where the constant influenced the final shape of the Elemwise operation
1313+
We broadcast (via alloc) the new Elemwise result:
1314+
1315+
Elemwise(row, ones((3, 4))) -> Alloc(Elemwise(row, ones((1, 1))), 3, 4)
1316+
1317+
This will avoid useless iterations over constant arrays.
1318+
"""
1319+
if len(node.inputs) == 1:
1320+
return None
1321+
1322+
new_elem_inps = []
1323+
ndims = node.outputs[0].type.ndim
1324+
found_const = False
1325+
for inp in node.inputs:
1326+
# If input has non-broadcastable dims
1327+
if not all(b for b in inp.type.broadcastable):
1328+
constant_value = get_unique_constant_value(inp)
1329+
if constant_value is not None:
1330+
constant_value = np.expand_dims(
1331+
constant_value, axis=tuple(range(ndims))
1332+
).astype(inp.type.dtype)
1333+
new_elem_inps.append(constant(constant_value))
1334+
found_const = True
1335+
continue
1336+
1337+
new_elem_inps.append(inp)
1338+
1339+
if not found_const:
1340+
return None
1341+
1342+
new_outs = node.op.make_node(*new_elem_inps).outputs
1343+
1344+
# The constants were needed to enforce the output shape
1345+
if node.outputs[0].type.broadcastable != new_outs[0].type.broadcastable:
1346+
new_outs = [
1347+
broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
1348+
for new_out in new_outs
1349+
]
1350+
1351+
copy_stack_trace(node.outputs, new_outs)
1352+
return new_outs
1353+
1354+
1355+
# We register this immediately after the fusion database.
1356+
# We don't want Allocs to break up the fusion rewrites
1357+
compile.optdb.register(
1358+
"local_replace_broadcasted_constants",
1359+
in2out(local_replace_broadcasted_constants),
1360+
"fast_run",
1361+
position=49.01,
1362+
)
1363+
1364+
12991365
def _rebuild_partial_2f1grad_loop(node, wrt):
13001366
a, b, c, log_z, sign_z = node.inputs[-5:]
13011367
z = exp(log_z) * sign_z

tests/tensor/rewriting/test_elemwise.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import Mode, get_default_mode
1313
from pytensor.configdefaults import config
1414
from pytensor.gradient import grad
15-
from pytensor.graph.basic import Constant, ancestors, equal_computations
15+
from pytensor.graph.basic import Constant, equal_computations
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -178,6 +178,21 @@ def test_dimshuffle_lift_multi_out_elemwise(self):
178178
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
179179

180180

181+
def test_local_replace_broadcasted_constants():
182+
const = np.full(shape=(2, 5), fill_value=2.6)
183+
x = scalar("x")
184+
out = at.power(x, const)
185+
new_out = rewrite_graph(
186+
out, include=["ShapeOpt", "local_replace_broadcasted_constants"]
187+
)
188+
ref_out = at.alloc(
189+
at.power(x, [[2.6]]),
190+
at.constant(2, dtype="int64"),
191+
at.constant(5, dtype="int64"),
192+
)
193+
assert equal_computations([new_out], [ref_out])
194+
195+
181196
def test_local_useless_dimshuffle_in_reshape():
182197
vec = TensorType(dtype="float64", shape=(None,))("vector")
183198
mat = TensorType(dtype="float64", shape=(None, None))("mat")

0 commit comments

Comments
 (0)