Skip to content

Commit b341661

Browse files
committed
Fix bug in Composite when multiple outputs are identical
1 parent 8532ac6 commit b341661

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

pytensor/scalar/basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4146,6 +4146,21 @@ def fgraph(self):
41464146
"The fgraph to Composite must be exclusively"
41474147
" composed of ScalarOp instances."
41484148
)
4149+
4150+
# Clone identical outputs that have been merged
4151+
if len(set(fgraph.outputs)) != len(self.outputs):
4152+
old_outputs = fgraph.outputs
4153+
new_outputs = []
4154+
for output in old_outputs:
4155+
if output not in new_outputs:
4156+
new_outputs.append(output)
4157+
else:
4158+
node = output.owner
4159+
output_idx = node.outputs.index(output)
4160+
new_output = node.clone().outputs[output_idx]
4161+
new_outputs.append(new_output)
4162+
fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
4163+
41494164
self._fgraph = fgraph
41504165
return self._fgraph
41514166

tests/scalar/test_basic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ def test_many_outputs(self):
156156
fn = make_function(DualLinker().accept(g))
157157
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
158158

159+
def test_identical_outputs(self):
160+
x, y, z = floats("xyz")
161+
e0 = x + y + z
162+
e1 = x + y + z
163+
e2 = x / y
164+
C = Composite([x, y, z], [e0, e1, e2])
165+
c = C.make_node(x, y, z)
166+
g = FunctionGraph([x, y, z], c.outputs)
167+
fn = make_function(DualLinker().accept(g))
168+
assert fn(1.0, 2.0, 3.0) == [6.0, 6.0, 0.5]
169+
159170
def test_composite_printing(self):
160171
x, y, z = floats("xyz")
161172
e0 = x + y + z

0 commit comments

Comments
 (0)