Skip to content

Commit f3380ef

Browse files
authored
Merge pull request #81240 from clackary/cherrypick/ossa-partial-apply-80662
[6.2 🍒] Fix ownership issues with sequences of partial_apply's in AutoDiff closure specialization pass
2 parents 929e51d + 35678bf commit f3380ef

File tree

2 files changed

+267
-48
lines changed

2 files changed

+267
-48
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ import SILBridging
105105

106106
private let verbose = false
107107

108-
private func log(_ message: @autoclosure () -> String) {
108+
private func log(prefix: Bool = true, _ message: @autoclosure () -> String) {
109109
if verbose {
110-
print("### \(message())")
110+
debugLog(prefix: prefix, message())
111111
}
112112
}
113113

@@ -128,47 +128,48 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
128128
}
129129

130130
var remainingSpecializationRounds = 5
131-
var callerModified = false
132131

133132
repeat {
133+
// TODO: Names here are pretty misleading. We are looking for a place where
134+
// the pullback closure is created (so for `partial_apply` instruction).
134135
var callSites = gatherCallSites(in: function, context)
136+
guard !callSites.isEmpty else {
137+
return
138+
}
135139

136-
if !callSites.isEmpty {
137-
for callSite in callSites {
138-
var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context)
139-
140-
if !alreadyExists {
141-
context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee)
142-
}
140+
for callSite in callSites {
141+
var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context)
143142

144-
rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)
143+
if !alreadyExists {
144+
context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee)
145145
}
146146

147-
var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in
148-
callSite.closureArgDescriptors
149-
.map { $0.closure }
150-
.forEach { deadClosures.pushIfNotVisited($0) }
151-
}
147+
rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)
148+
}
152149

153-
defer {
154-
deadClosures.deinitialize()
155-
}
150+
var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in
151+
callSite.closureArgDescriptors
152+
.map { $0.closure }
153+
.forEach { deadClosures.pushIfNotVisited($0) }
154+
}
156155

157-
while let deadClosure = deadClosures.pop() {
158-
let isDeleted = context.tryDeleteDeadClosure(closure: deadClosure as! SingleValueInstruction)
159-
if isDeleted {
160-
context.notifyInvalidatedStackNesting()
161-
}
162-
}
156+
defer {
157+
deadClosures.deinitialize()
158+
}
163159

164-
if context.needFixStackNesting {
165-
function.fixStackNesting(context)
160+
while let deadClosure = deadClosures.pop() {
161+
let isDeleted = context.tryDeleteDeadClosure(closure: deadClosure as! SingleValueInstruction)
162+
if isDeleted {
163+
context.notifyInvalidatedStackNesting()
166164
}
167165
}
168166

169-
callerModified = callSites.count > 0
167+
if context.needFixStackNesting {
168+
function.fixStackNesting(context)
169+
}
170+
170171
remainingSpecializationRounds -= 1
171-
} while callerModified && remainingSpecializationRounds > 0
172+
} while remainingSpecializationRounds > 0
172173
}
173174

174175
// =========== Top-level functions ========== //
@@ -503,12 +504,6 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
503504
continue
504505
}
505506

506-
// Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78847
507-
// TODO: remove this if-statement once the underlying problem is fixed.
508-
if callee.hasOwnership {
509-
continue
510-
}
511-
512507
if callee.isDefinedExternally {
513508
continue
514509
}
@@ -779,13 +774,13 @@ private extension SpecializationCloner {
779774

780775
let clonedRootClosure = builder.cloneRootClosure(representedBy: closureArgDesc, capturedArguments: clonedClosureArgs)
781776

782-
let (finalClonedReabstractedClosure, releasableClonedReabstractedClosures) =
777+
let finalClonedReabstractedClosure =
783778
builder.cloneRootClosureReabstractions(rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure,
784779
reabstractedClosure: callSite.appliedArgForClosure(at: closureArgDesc.closureArgIndex)!,
785780
origToClonedValueMap: origToClonedValueMap,
786781
self.context)
787782

788-
let allClonedReleasableClosures = [clonedRootClosure] + releasableClonedReabstractedClosures
783+
let allClonedReleasableClosures = [ finalClonedReabstractedClosure ];
789784
return (finalClonedReabstractedClosure, allClonedReleasableClosures)
790785
}
791786

@@ -935,10 +930,9 @@ private extension Builder {
935930

936931
func cloneRootClosureReabstractions(rootClosure: Value, clonedRootClosure: Value, reabstractedClosure: Value,
937932
origToClonedValueMap: [HashableValue: Value], _ context: FunctionPassContext)
938-
-> (finalClonedReabstractedClosure: SingleValueInstruction, releasableClonedReabstractedClosures: [PartialApplyInst])
933+
-> SingleValueInstruction
939934
{
940935
func inner(_ rootClosure: Value, _ clonedRootClosure: Value, _ reabstractedClosure: Value,
941-
_ releasableClonedReabstractedClosures: inout [PartialApplyInst],
942936
_ origToClonedValueMap: inout [HashableValue: Value]) -> Value {
943937
switch reabstractedClosure {
944938
case let reabstractedClosure where reabstractedClosure == rootClosure:
@@ -947,23 +941,23 @@ private extension Builder {
947941

948942
case let cvt as ConvertFunctionInst:
949943
let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction,
950-
&releasableClonedReabstractedClosures, &origToClonedValueMap)
944+
&origToClonedValueMap)
951945
let reabstracted = self.createConvertFunction(originalFunction: toBeReabstracted, resultType: cvt.type,
952946
withoutActuallyEscaping: cvt.withoutActuallyEscaping)
953947
origToClonedValueMap[cvt] = reabstracted
954948
return reabstracted
955949

956950
case let cvt as ConvertEscapeToNoEscapeInst:
957951
let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction,
958-
&releasableClonedReabstractedClosures, &origToClonedValueMap)
952+
&origToClonedValueMap)
959953
let reabstracted = self.createConvertEscapeToNoEscape(originalFunction: toBeReabstracted, resultType: cvt.type,
960954
isLifetimeGuaranteed: true)
961955
origToClonedValueMap[cvt] = reabstracted
962956
return reabstracted
963957

964958
case let pai as PartialApplyInst:
965959
let toBeReabstracted = inner(rootClosure, clonedRootClosure, pai.arguments[0],
966-
&releasableClonedReabstractedClosures, &origToClonedValueMap)
960+
&origToClonedValueMap)
967961

968962
guard let function = pai.referencedFunction else {
969963
log("Parent function of callSite: \(rootClosure.parentFunction)")
@@ -978,13 +972,11 @@ private extension Builder {
978972
calleeConvention: pai.calleeConvention,
979973
hasUnknownResultIsolation: pai.hasUnknownResultIsolation,
980974
isOnStack: pai.isOnStack)
981-
releasableClonedReabstractedClosures.append(reabstracted)
982975
origToClonedValueMap[pai] = reabstracted
983976
return reabstracted
984977

985978
case let mdi as MarkDependenceInst:
986-
let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &releasableClonedReabstractedClosures,
987-
&origToClonedValueMap)
979+
let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &origToClonedValueMap)
988980
let base = origToClonedValueMap[mdi.base]!
989981
let reabstracted = self.createMarkDependence(value: toBeReabstracted, base: base, kind: .Escaping)
990982
origToClonedValueMap[mdi] = reabstracted
@@ -998,11 +990,10 @@ private extension Builder {
998990
}
999991
}
1000992

1001-
var releasableClonedReabstractedClosures: [PartialApplyInst] = []
1002993
var origToClonedValueMap = origToClonedValueMap
1003994
let finalClonedReabstractedClosure = inner(rootClosure, clonedRootClosure, reabstractedClosure,
1004-
&releasableClonedReabstractedClosures, &origToClonedValueMap)
1005-
return (finalClonedReabstractedClosure as! SingleValueInstruction, releasableClonedReabstractedClosures)
995+
&origToClonedValueMap)
996+
return (finalClonedReabstractedClosure as! SingleValueInstruction)
1006997
}
1007998

1008999
func destroyPartialApply(pai: PartialApplyInst, _ context: FunctionPassContext){

0 commit comments

Comments
 (0)