@@ -58,6 +58,14 @@ def print_profile(cls, stream, prof, level=0):
58
58
for n in sorted (ndim .keys ()):
59
59
print (blanc , n , ndim [n ], file = stream )
60
60
61
+ def candidate_input_idxs (self , node ):
62
+ if isinstance (node .op .scalar_op , aes .Composite ) and len (node .outputs ) > 1 :
63
+ # TODO: Implement specialized InplaceCompositeOptimizer with logic
64
+ # needed to correctly assign inplace for multi-output Composites
65
+ return []
66
+ else :
67
+ return range (len (node .outputs ))
68
+
61
69
def apply (self , fgraph ):
62
70
r"""
63
71
@@ -148,7 +156,7 @@ def apply(self, fgraph):
148
156
149
157
baseline = op .inplace_pattern
150
158
candidate_outputs = [
151
- i for i in range ( len ( node . outputs ) ) if i not in baseline
159
+ i for i in self . candidate_input_idxs ( node ) if i not in baseline
152
160
]
153
161
# node inputs that are Constant, already destroyed,
154
162
# or fgraph protected inputs and fgraph outputs can't be used as
@@ -166,7 +174,7 @@ def apply(self, fgraph):
166
174
]
167
175
else :
168
176
baseline = []
169
- candidate_outputs = list ( range ( len ( node . outputs )) )
177
+ candidate_outputs = self . candidate_input_idxs ( node )
170
178
# node inputs that are Constant, already destroyed,
171
179
# fgraph protected inputs and fgraph outputs can't be used as inplace
172
180
# target.
0 commit comments