Skip to content

Commit 1f49094

Browse files
committed
Disable invalid inplace logic for multiple-output Composites
1 parent da87b0c commit 1f49094

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ def print_profile(cls, stream, prof, level=0):
5959
for n in sorted(ndim.keys()):
6060
print(blanc, n, ndim[n], file=stream)
6161

62+
def candidate_input_idxs(self, node):
63+
if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1:
64+
# TODO: Implement specialized InplaceCompositeOptimizer with logic
65+
# needed to correctly assign inplace for multi-output Composites
66+
return []
67+
else:
68+
return range(len(node.outputs))
69+
6270
def apply(self, fgraph):
6371
r"""
6472
@@ -149,7 +157,7 @@ def apply(self, fgraph):
149157

150158
baseline = op.inplace_pattern
151159
candidate_outputs = [
152-
i for i in range(len(node.outputs)) if i not in baseline
160+
i for i in self.candidate_input_idxs(node) if i not in baseline
153161
]
154162
# node inputs that are Constant, already destroyed,
155163
# or fgraph protected inputs and fgraph outputs can't be used as
@@ -167,7 +175,7 @@ def apply(self, fgraph):
167175
]
168176
else:
169177
baseline = []
170-
candidate_outputs = list(range(len(node.outputs)))
178+
candidate_outputs = self.candidate_input_idxs(node)
171179
# node inputs that are Constant, already destroyed,
172180
# fgraph protected inputs and fgraph outputs can't be used as inplace
173181
# target.

tests/tensor/rewriting/test_elemwise.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
import pytensor
7+
from pytensor import In
78
from pytensor import scalar as aes
89
from pytensor import shared
910
from pytensor import tensor as at
@@ -1024,6 +1025,34 @@ def test_add_mul_fusion_inplace(self):
10241025
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
10251026
)
10261027

1028+
def test_fusion_multiout_inplace(self):
1029+
x = vector("x")
1030+
1031+
# Create Composite where inplacing the first non-constant output would corrupt the second output
1032+
xs = aes.float64("xs")
1033+
outs = (
1034+
Elemwise(Composite([xs], [xs + 1, aes.cos(xs + 1) + xs]))
1035+
.make_node(x)
1036+
.outputs
1037+
)
1038+
1039+
f = pytensor.function(
1040+
[In(x, mutable=True)],
1041+
outs,
1042+
mode=self.mode.including("inplace"),
1043+
)
1044+
(composite_node,) = f.maker.fgraph.apply_nodes
1045+
1046+
# Destroy map must be None or the last toposorted output
1047+
destroy_map = composite_node.op.destroy_map
1048+
assert (destroy_map == {}) or (
1049+
destroy_map == {1: [composite_node.inputs.index(x)]}
1050+
)
1051+
1052+
res = f([0, 1, 2])
1053+
assert np.allclose(res[0], [1, 2, 3])
1054+
assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2]))
1055+
10271056
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
10281057
def test_no_c_code(self):
10291058
r"""Make sure we avoid fusions for `Op`\s without C code implementations."""

0 commit comments

Comments
 (0)