Skip to content

Commit dd45099

Browse files
committed
Exclude unnecessary inputs in useless_composite rewrite
1 parent 1f49094 commit dd45099

File tree

2 files changed

+42
-17
lines changed

2 files changed

+42
-17
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -990,23 +990,33 @@ def print_profile(cls, stream, prof, level=0):
990990

991991

992992
@register_canonicalize
993+
@register_specialize
993994
@node_rewriter([Elemwise])
994995
def local_useless_composite(fgraph, node):
995-
"""For elemwise Composite that have multiple outputs, remove the
996-
outputs that are not used.
997-
998-
"""
996+
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
999997
if not isinstance(node.op, Elemwise) or not isinstance(
1000998
node.op.scalar_op, aes.Composite
1001999
):
10021000
return
10031001
comp = node.op.scalar_op
1004-
idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
1005-
if len(idx) < len(node.outputs):
1006-
new_outputs = [comp.outputs[i] for i in idx]
1007-
c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
1008-
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
1009-
return dict(zip([node.outputs[i] for i in idx], e))
1002+
used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
1003+
used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs]
1004+
comp_fgraph = FunctionGraph(
1005+
inputs=comp.inputs, outputs=used_inner_outputs, clone=False
1006+
)
1007+
used_inputs_idxs = [
1008+
i
1009+
for i, i_intern in enumerate(comp_fgraph.inputs)
1010+
if comp_fgraph.clients[i_intern]
1011+
]
1012+
used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs]
1013+
if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len(
1014+
node.outputs
1015+
):
1016+
used_inputs = [node.inputs[i] for i in used_inputs_idxs]
1017+
c = aes.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
1018+
e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
1019+
return dict(zip([node.outputs[i] for i in used_outputs_idxs], e))
10101020

10111021

10121022
@node_rewriter([CAReduce])

tests/tensor/rewriting/test_elemwise.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,22 +1292,37 @@ def test_nested_composite(self):
12921292

12931293
def test_local_useless_composite(self):
12941294
x = aes.float32()
1295-
c = aes.Composite([x], [x + 1, x - 1])
1296-
X = matrix()
1297-
o = Elemwise(scalar_op=c)(X)
1295+
y = aes.float32()
1296+
z = aes.float32()
1297+
c = aes.Composite([x, y, z], [x + 1, y - 1])
1298+
X = matrix("X")
1299+
Y = matrix("Y")
1300+
Z = matrix("Z")
1301+
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
12981302
mode = get_default_mode().including("local_useless_composite")
12991303

1300-
f = function([X], o[0], mode=mode)
1304+
f = function([X, Y, Z], [o1, o2], mode=mode)
13011305
topo = f.maker.fgraph.toposort()
13021306
assert len(topo) == 1
1307+
assert len(topo[0].inputs) == 2
1308+
assert len(topo[0].outputs) == 2
1309+
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
1310+
utt.assert_allclose(res1, [[2.0]])
1311+
utt.assert_allclose(res2, [[0.0]])
1312+
1313+
f = function([X, Y, Z], o1, mode=mode)
1314+
topo = f.maker.fgraph.toposort()
1315+
assert len(topo) == 1
1316+
assert len(topo[0].inputs) == 1
13031317
assert len(topo[0].outputs) == 1
1304-
utt.assert_allclose(f([[1.0]]), [[2.0]])
1318+
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
13051319

1306-
f = function([X], o[1], mode=mode)
1320+
f = function([X, Y, Z], o2, mode=mode)
13071321
topo = f.maker.fgraph.toposort()
13081322
assert len(topo) == 1
1323+
assert len(topo[0].inputs) == 1
13091324
assert len(topo[0].outputs) == 1
1310-
utt.assert_allclose(f([[1.0]]), [[0.0]])
1325+
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
13111326

13121327

13131328
def test_local_useless_dimshuffle_makevector():

0 commit comments

Comments
 (0)