Skip to content

Commit 8289de3

Browse files
committed
Add rewrite to remove broadcasted constants from Elemwise graphs
1 parent 027f64a commit 8289de3

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar
55
from warnings import warn
66

7+
import numpy as np
8+
79
import pytensor
810
import pytensor.scalar.basic as aes
911
from pytensor import clone_replace, compile
@@ -28,14 +30,19 @@
2830
MakeVector,
2931
alloc,
3032
cast,
33+
constant,
3134
get_underlying_scalar_constant_value,
3235
)
3336
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
3437
from pytensor.tensor.exceptions import NotScalarConstantError
3538
from pytensor.tensor.math import exp
36-
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
39+
from pytensor.tensor.rewriting.basic import (
40+
broadcast_like,
41+
register_canonicalize,
42+
register_specialize,
43+
)
3744
from pytensor.tensor.shape import shape_padleft
38-
from pytensor.tensor.var import TensorConstant
45+
from pytensor.tensor.var import TensorConstant, get_unique_constant_value
3946

4047

4148
class InplaceElemwiseOptimizer(GraphRewriter):
@@ -552,6 +559,50 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
552559
return rval
553560

554561

562+
@register_specialize
563+
@node_rewriter([Elemwise])
564+
def local_replace_broadcasted_constant(fgraph, node):
565+
"""Remove broadcasted constants from Elemwise graphs
566+
567+
Elemwise(scalar, ones((3, 4))) -> Alloc(Elemwise(scalar, ones((1, 1))), 3, 4)
568+
569+
This will avoid a useless iterations over constant arrays
570+
"""
571+
if len(node.inputs) == 1:
572+
return None
573+
574+
new_elem_inps = []
575+
ndims = node.outputs[0].type.ndim
576+
found_const = False
577+
for inp in node.inputs:
578+
# If input has non-broadcastable dims
579+
if not all(b for b in inp.type.broadcastable):
580+
constant_value = get_unique_constant_value(inp)
581+
if constant_value is not None:
582+
constant_value = np.expand_dims(
583+
constant_value, axis=tuple(range(ndims))
584+
).astype(inp.type.dtype)
585+
new_elem_inps.append(constant(constant_value))
586+
found_const = True
587+
continue
588+
589+
new_elem_inps.append(inp)
590+
591+
if not found_const:
592+
return None
593+
594+
new_outs = node.op.make_node(*new_elem_inps).outputs
595+
# The constants were needed to enforce the output shape
596+
if node.outputs[0].type.broadcastable != new_outs[0].type.broadcastable:
597+
new_outs = [
598+
broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
599+
for new_out in new_outs
600+
]
601+
602+
copy_stack_trace(node.outputs, new_outs)
603+
return new_outs
604+
605+
555606
@node_rewriter([Elemwise])
556607
def local_add_mul_fusion(fgraph, node):
557608
"""Fuse consecutive add or mul in one such node with more inputs.

tests/tensor/rewriting/test_elemwise.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,19 @@ def test_dimshuffle_lift_multi_out_elemwise(self):
178178
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
179179

180180

181+
def test_local_replace_broadcasted_constant():
182+
const = np.full(shape=(2, 5), fill_value=2.6)
183+
x = scalar("x")
184+
out = at.power(x, const)
185+
new_out = rewrite_graph(out, include=["ShapeOpt", "specialize"])
186+
ref_out = at.alloc(
187+
at.power(x, [[2.6]]),
188+
at.constant(2, dtype="int64"),
189+
at.constant(5, dtype="int64"),
190+
)
191+
assert equal_computations([new_out], [ref_out])
192+
193+
181194
def test_local_useless_dimshuffle_in_reshape():
182195
vec = TensorType(dtype="float64", shape=(None,))("vector")
183196
mat = TensorType(dtype="float64", shape=(None, None))("mat")

0 commit comments

Comments
 (0)