Skip to content

Commit e735560

Browse files
committed
Exclude unnecessary inputs from in useless_composite rewrite
1 parent 800a118 commit e735560

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

pytensor/tensor/rewriting/elemwise.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -1094,21 +1094,30 @@ def print_profile(stream, prof, level=0):
10941094
@register_canonicalize
10951095
@node_rewriter([Elemwise])
10961096
def local_useless_composite(fgraph, node):
1097-
"""For elemwise Composite that have multiple outputs, remove the
1098-
outputs that are not used.
1099-
1100-
"""
1097+
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
11011098
if not isinstance(node.op, Elemwise) or not isinstance(
11021099
node.op.scalar_op, aes.Composite
11031100
):
11041101
return
11051102
comp = node.op.scalar_op
1106-
idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
1107-
if len(idx) < len(node.outputs):
1108-
new_outputs = [comp.outputs[i] for i in idx]
1109-
c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
1110-
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
1111-
return dict(zip([node.outputs[i] for i in idx], e))
1103+
used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
1104+
used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs]
1105+
comp_fgraph = FunctionGraph(
1106+
inputs=comp.inputs, outputs=used_inner_outputs, clone=False
1107+
)
1108+
used_inputs_idxs = [
1109+
i
1110+
for i, i_intern in enumerate(comp_fgraph.inputs)
1111+
if comp_fgraph.clients[i_intern]
1112+
]
1113+
used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs]
1114+
if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len(
1115+
node.outputs
1116+
):
1117+
used_inputs = [node.inputs[i] for i in used_inputs_idxs]
1118+
c = aes.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
1119+
e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
1120+
return dict(zip([node.outputs[i] for i in used_outputs_idxs], e))
11121121

11131122

11141123
@node_rewriter([CAReduce])

tests/tensor/rewriting/test_elemwise.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -1441,22 +1441,35 @@ def test_nested_composite(self):
14411441

14421442
def test_local_useless_composite(self):
14431443
x = aes.float32()
1444-
c = aes.Composite([x], [x + 1, x - 1])
1445-
X = matrix()
1446-
o = Elemwise(scalar_op=c)(X)
1444+
y = aes.float32()
1445+
c = aes.Composite([x, y], [x + 1, y - 1])
1446+
X = matrix("X")
1447+
Y = matrix("Y")
1448+
o1, o2 = Elemwise(scalar_op=c)(X, Y)
14471449
mode = get_default_mode().including("local_useless_composite")
14481450

1449-
f = function([X], o[0], mode=mode)
1451+
f = function([X, Y], [o1, o2], mode=mode)
14501452
topo = f.maker.fgraph.toposort()
14511453
assert len(topo) == 1
1454+
assert len(topo[0].inputs) == 2
1455+
assert len(topo[0].outputs) == 2
1456+
res1, res2 = f([[1.0]], [[1.0]])
1457+
utt.assert_allclose(res1, [[2.0]])
1458+
utt.assert_allclose(res2, [[0.0]])
1459+
1460+
f = function([X, Y], o1, mode=mode)
1461+
topo = f.maker.fgraph.toposort()
1462+
assert len(topo) == 1
1463+
assert len(topo[0].inputs) == 1
14521464
assert len(topo[0].outputs) == 1
1453-
utt.assert_allclose(f([[1.0]]), [[2.0]])
1465+
utt.assert_allclose(f([[1.0]], [[np.nan]]), [[2.0]])
14541466

1455-
f = function([X], o[1], mode=mode)
1467+
f = function([X, Y], o2, mode=mode)
14561468
topo = f.maker.fgraph.toposort()
14571469
assert len(topo) == 1
1470+
assert len(topo[0].inputs) == 1
14581471
assert len(topo[0].outputs) == 1
1459-
utt.assert_allclose(f([[1.0]]), [[0.0]])
1472+
utt.assert_allclose(f([[np.nan]], [[1.0]]), [[0.0]])
14601473

14611474

14621475
def test_local_useless_dimshuffle_makevector():

0 commit comments

Comments
 (0)