Skip to content

Commit aa9ca61

Browse files
committed
Temporarily disable inplace for multiple-output Composites
1 parent ce1eb4d commit aa9ca61

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

pytensor/tensor/rewriting/elemwise.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ def print_profile(cls, stream, prof, level=0):
5858
for n in sorted(ndim.keys()):
5959
print(blanc, n, ndim[n], file=stream)
6060

61+
def candidate_input_idxs(self, node):
62+
if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1:
63+
# TODO: Implement specialized InplaceCompositeOptimizer with logic
64+
# needed to correctly assign inplace for multi-output Composites
65+
return []
66+
else:
67+
return range(len(node.outputs))
68+
6169
def apply(self, fgraph):
6270
r"""
6371
@@ -148,7 +156,7 @@ def apply(self, fgraph):
148156

149157
baseline = op.inplace_pattern
150158
candidate_outputs = [
151-
i for i in range(len(node.outputs)) if i not in baseline
159+
i for i in self.candidate_input_idxs(node) if i not in baseline
152160
]
153161
# node inputs that are Constant, already destroyed,
154162
# or fgraph protected inputs and fgraph outputs can't be used as
@@ -166,7 +174,7 @@ def apply(self, fgraph):
166174
]
167175
else:
168176
baseline = []
169-
candidate_outputs = list(range(len(node.outputs)))
177+
candidate_outputs = self.candidate_input_idxs(node)
170178
# node inputs that are Constant, already destroyed,
171179
# fgraph protected inputs and fgraph outputs can't be used as inplace
172180
# target.

0 commit comments

Comments
 (0)