@@ -105,9 +105,9 @@ import SILBridging
105
105
106
106
private let verbose = false
107
107
108
- private func log( _ message: @autoclosure ( ) -> String ) {
108
+ private func log( prefix : Bool = true , _ message: @autoclosure ( ) -> String ) {
109
109
if verbose {
110
- print ( " ### \( message ( ) ) " )
110
+ debugLog ( prefix : prefix , message ( ) )
111
111
}
112
112
}
113
113
@@ -128,47 +128,48 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
128
128
}
129
129
130
130
var remainingSpecializationRounds = 5
131
- var callerModified = false
132
131
133
132
repeat {
133
+ // TODO: Names here are pretty misleading. We are looking for a place where
134
+ // the pullback closure is created (so for `partial_apply` instruction).
134
135
var callSites = gatherCallSites ( in: function, context)
136
+ guard !callSites. isEmpty else {
137
+ return
138
+ }
135
139
136
- if !callSites. isEmpty {
137
- for callSite in callSites {
138
- var ( specializedFunction, alreadyExists) = getOrCreateSpecializedFunction ( basedOn: callSite, context)
139
-
140
- if !alreadyExists {
141
- context. notifyNewFunction ( function: specializedFunction, derivedFrom: callSite. applyCallee)
142
- }
140
+ for callSite in callSites {
141
+ var ( specializedFunction, alreadyExists) = getOrCreateSpecializedFunction ( basedOn: callSite, context)
143
142
144
- rewriteApplyInstruction ( using: specializedFunction, callSite: callSite, context)
143
+ if !alreadyExists {
144
+ context. notifyNewFunction ( function: specializedFunction, derivedFrom: callSite. applyCallee)
145
145
}
146
146
147
- var deadClosures : InstructionWorklist = callSites. reduce ( into: InstructionWorklist ( context) ) { deadClosures, callSite in
148
- callSite. closureArgDescriptors
149
- . map { $0. closure }
150
- . forEach { deadClosures. pushIfNotVisited ( $0) }
151
- }
147
+ rewriteApplyInstruction ( using: specializedFunction, callSite: callSite, context)
148
+ }
152
149
153
- defer {
154
- deadClosures. deinitialize ( )
155
- }
150
+ var deadClosures : InstructionWorklist = callSites. reduce ( into: InstructionWorklist ( context) ) { deadClosures, callSite in
151
+ callSite. closureArgDescriptors
152
+ . map { $0. closure }
153
+ . forEach { deadClosures. pushIfNotVisited ( $0) }
154
+ }
156
155
157
- while let deadClosure = deadClosures. pop ( ) {
158
- let isDeleted = context. tryDeleteDeadClosure ( closure: deadClosure as! SingleValueInstruction )
159
- if isDeleted {
160
- context. notifyInvalidatedStackNesting ( )
161
- }
162
- }
156
+ defer {
157
+ deadClosures. deinitialize ( )
158
+ }
163
159
164
- if context. needFixStackNesting {
165
- function. fixStackNesting ( context)
160
+ while let deadClosure = deadClosures. pop ( ) {
161
+ let isDeleted = context. tryDeleteDeadClosure ( closure: deadClosure as! SingleValueInstruction )
162
+ if isDeleted {
163
+ context. notifyInvalidatedStackNesting ( )
166
164
}
167
165
}
168
166
169
- callerModified = callSites. count > 0
167
+ if context. needFixStackNesting {
168
+ function. fixStackNesting ( context)
169
+ }
170
+
170
171
remainingSpecializationRounds -= 1
171
- } while callerModified && remainingSpecializationRounds > 0
172
+ } while remainingSpecializationRounds > 0
172
173
}
173
174
174
175
// =========== Top-level functions ========== //
@@ -503,12 +504,6 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
503
504
continue
504
505
}
505
506
506
- // Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78847
507
- // TODO: remove this if-statement once the underlying problem is fixed.
508
- if callee. hasOwnership {
509
- continue
510
- }
511
-
512
507
if callee. isDefinedExternally {
513
508
continue
514
509
}
@@ -779,13 +774,13 @@ private extension SpecializationCloner {
779
774
780
775
let clonedRootClosure = builder. cloneRootClosure ( representedBy: closureArgDesc, capturedArguments: clonedClosureArgs)
781
776
782
- let ( finalClonedReabstractedClosure, releasableClonedReabstractedClosures ) =
777
+ let finalClonedReabstractedClosure =
783
778
builder. cloneRootClosureReabstractions ( rootClosure: closureArgDesc. closure, clonedRootClosure: clonedRootClosure,
784
779
reabstractedClosure: callSite. appliedArgForClosure ( at: closureArgDesc. closureArgIndex) !,
785
780
origToClonedValueMap: origToClonedValueMap,
786
781
self . context)
787
782
788
- let allClonedReleasableClosures = [ clonedRootClosure ] + releasableClonedReabstractedClosures
783
+ let allClonedReleasableClosures = [ finalClonedReabstractedClosure ] ;
789
784
return ( finalClonedReabstractedClosure, allClonedReleasableClosures)
790
785
}
791
786
@@ -935,10 +930,9 @@ private extension Builder {
935
930
936
931
func cloneRootClosureReabstractions( rootClosure: Value , clonedRootClosure: Value , reabstractedClosure: Value ,
937
932
origToClonedValueMap: [ HashableValue : Value ] , _ context: FunctionPassContext )
938
- -> ( finalClonedReabstractedClosure : SingleValueInstruction , releasableClonedReabstractedClosures : [ PartialApplyInst ] )
933
+ -> SingleValueInstruction
939
934
{
940
935
func inner( _ rootClosure: Value , _ clonedRootClosure: Value , _ reabstractedClosure: Value ,
941
- _ releasableClonedReabstractedClosures: inout [ PartialApplyInst ] ,
942
936
_ origToClonedValueMap: inout [ HashableValue : Value ] ) -> Value {
943
937
switch reabstractedClosure {
944
938
case let reabstractedClosure where reabstractedClosure == rootClosure:
@@ -947,23 +941,23 @@ private extension Builder {
947
941
948
942
case let cvt as ConvertFunctionInst :
949
943
let toBeReabstracted = inner ( rootClosure, clonedRootClosure, cvt. fromFunction,
950
- & releasableClonedReabstractedClosures , & origToClonedValueMap)
944
+ & origToClonedValueMap)
951
945
let reabstracted = self . createConvertFunction ( originalFunction: toBeReabstracted, resultType: cvt. type,
952
946
withoutActuallyEscaping: cvt. withoutActuallyEscaping)
953
947
origToClonedValueMap [ cvt] = reabstracted
954
948
return reabstracted
955
949
956
950
case let cvt as ConvertEscapeToNoEscapeInst :
957
951
let toBeReabstracted = inner ( rootClosure, clonedRootClosure, cvt. fromFunction,
958
- & releasableClonedReabstractedClosures , & origToClonedValueMap)
952
+ & origToClonedValueMap)
959
953
let reabstracted = self . createConvertEscapeToNoEscape ( originalFunction: toBeReabstracted, resultType: cvt. type,
960
954
isLifetimeGuaranteed: true )
961
955
origToClonedValueMap [ cvt] = reabstracted
962
956
return reabstracted
963
957
964
958
case let pai as PartialApplyInst :
965
959
let toBeReabstracted = inner ( rootClosure, clonedRootClosure, pai. arguments [ 0 ] ,
966
- & releasableClonedReabstractedClosures , & origToClonedValueMap)
960
+ & origToClonedValueMap)
967
961
968
962
guard let function = pai. referencedFunction else {
969
963
log ( " Parent function of callSite: \( rootClosure. parentFunction) " )
@@ -978,13 +972,11 @@ private extension Builder {
978
972
calleeConvention: pai. calleeConvention,
979
973
hasUnknownResultIsolation: pai. hasUnknownResultIsolation,
980
974
isOnStack: pai. isOnStack)
981
- releasableClonedReabstractedClosures. append ( reabstracted)
982
975
origToClonedValueMap [ pai] = reabstracted
983
976
return reabstracted
984
977
985
978
case let mdi as MarkDependenceInst :
986
- let toBeReabstracted = inner ( rootClosure, clonedRootClosure, mdi. value, & releasableClonedReabstractedClosures,
987
- & origToClonedValueMap)
979
+ let toBeReabstracted = inner ( rootClosure, clonedRootClosure, mdi. value, & origToClonedValueMap)
988
980
let base = origToClonedValueMap [ mdi. base] !
989
981
let reabstracted = self . createMarkDependence ( value: toBeReabstracted, base: base, kind: . Escaping)
990
982
origToClonedValueMap [ mdi] = reabstracted
@@ -998,11 +990,10 @@ private extension Builder {
998
990
}
999
991
}
1000
992
1001
- var releasableClonedReabstractedClosures : [ PartialApplyInst ] = [ ]
1002
993
var origToClonedValueMap = origToClonedValueMap
1003
994
let finalClonedReabstractedClosure = inner ( rootClosure, clonedRootClosure, reabstractedClosure,
1004
- & releasableClonedReabstractedClosures , & origToClonedValueMap)
1005
- return ( finalClonedReabstractedClosure as! SingleValueInstruction , releasableClonedReabstractedClosures )
995
+ & origToClonedValueMap)
996
+ return ( finalClonedReabstractedClosure as! SingleValueInstruction )
1006
997
}
1007
998
1008
999
func destroyPartialApply( pai: PartialApplyInst , _ context: FunctionPassContext ) {
0 commit comments