Skip to content

Commit 7218431

Browse files
authored
Add rewrite for Sum(MakeVector) (#346)
1 parent 981be2a commit 7218431

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from pytensor.tensor.elemwise import DimShuffle, Elemwise
4444
from pytensor.tensor.exceptions import NotScalarConstantError
4545
from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
46+
from pytensor.tensor.math import Sum, add
4647
from pytensor.tensor.math import all as at_all
4748
from pytensor.tensor.math import eq
4849
from pytensor.tensor.shape import Shape_i
@@ -956,6 +957,41 @@ def local_join_make_vector(fgraph, node):
956957
return [ret]
957958

958959

960+
@register_specialize
961+
@register_canonicalize
962+
@register_useless
963+
@node_rewriter([Sum])
964+
def local_sum_make_vector(fgraph, node):
965+
"""A sum of a MakeVector node is just the sum of the elements."""
966+
(array,) = node.inputs
967+
968+
if array.owner is None:
969+
return
970+
971+
if not isinstance(array.owner.op, MakeVector):
972+
return
973+
974+
if node.op.axis == ():
975+
return [array]
976+
977+
# If this is not the case the sum is invalid
978+
assert node.op.axis is None or node.op.axis == (0,) or node.op.axis == (-1,)
979+
980+
elements = array.owner.inputs
981+
acc_dtype = node.op.acc_dtype
982+
out_dtype = node.op.dtype
983+
if len(elements) == 0:
984+
element_sum = zeros(dtype=out_dtype, shape=())
985+
elif len(elements) == 1:
986+
element_sum = cast(elements[0], out_dtype)
987+
else:
988+
element_sum = cast(
989+
add(*[cast(value, acc_dtype) for value in elements]), out_dtype
990+
)
991+
992+
return [element_sum]
993+
994+
959995
@register_useless("local_remove_switch_const_cond")
960996
@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
961997
@register_specialize

tests/tensor/rewriting/test_basic.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import get_default_mode, get_mode
1313
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
1414
from pytensor.configdefaults import config
15-
from pytensor.graph.basic import equal_computations
15+
from pytensor.graph.basic import equal_computations, vars_between
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -31,6 +31,7 @@
3131
)
3232
from pytensor.tensor.elemwise import DimShuffle, Elemwise
3333
from pytensor.tensor.math import (
34+
Sum,
3435
add,
3536
bitwise_and,
3637
bitwise_or,
@@ -1300,6 +1301,44 @@ def test_local_join_make_vector():
13001301
assert check_stack_trace(f, ops_to_check="all")
13011302

13021303

1304+
def test_local_sum_make_vector():
1305+
a, b, c = scalars("abc")
1306+
mv = MakeVector(config.floatX)
1307+
output = mv(a, b, c).sum()
1308+
1309+
output = rewrite_graph(output)
1310+
between = vars_between([a, b, c], [output])
1311+
for var in between:
1312+
assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))
1313+
1314+
# Check for empty sum
1315+
a, b, c = scalars("abc")
1316+
mv = MakeVector(config.floatX)
1317+
output = mv(a, b, c).sum(axis=[])
1318+
1319+
output = rewrite_graph(output)
1320+
between = vars_between([a, b, c], [output])
1321+
for var in between:
1322+
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1323+
1324+
# Check empty MakeVector
1325+
mv = MakeVector(config.floatX)
1326+
output = mv().sum()
1327+
1328+
output = rewrite_graph(output)
1329+
between = vars_between([a, b, c], [output])
1330+
for var in between:
1331+
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1332+
1333+
mv = MakeVector(config.floatX)
1334+
output = mv(a).sum()
1335+
1336+
output = rewrite_graph(output)
1337+
between = vars_between([a, b, c], [output])
1338+
for var in between:
1339+
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1340+
1341+
13031342
@pytest.mark.parametrize(
13041343
"dtype",
13051344
[

0 commit comments

Comments
 (0)