Skip to content

[NFC] Utilities for lifetime dependent coroutines #77675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 18, 2024
24 changes: 24 additions & 0 deletions SwiftCompilerSources/Sources/SIL/Builder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ public struct Builder {
private let notificationHandler: BridgedChangeNotificationHandler
private let notifyNewInstruction: (Instruction) -> ()

/// Return 'nil' when inserting at the start of a function or in a global initializer.
public var insertionBlock: BasicBlock? {
switch insertAt {
case let .before(inst):
return inst.parentBlock
case let .atEndOf(block):
return block
case .atStartOf, .staticInitializer:
return nil
}
}

public var bridged: BridgedBuilder {
switch insertAt {
case .before(let inst):
Expand Down Expand Up @@ -482,6 +494,18 @@ public struct Builder {
return notifyNew(endAccess.getAs(EndAccessInst.self))
}

@discardableResult
public func createEndApply(beginApply: BeginApplyInst) -> EndApplyInst {
let endApply = bridged.createEndApply(beginApply.token.bridged)
return notifyNew(endApply.getAs(EndApplyInst.self))
}

@discardableResult
public func createAbortApply(beginApply: BeginApplyInst) -> AbortApplyInst {
let endApply = bridged.createAbortApply(beginApply.token.bridged)
return notifyNew(endApply.getAs(AbortApplyInst.self))
}

public func createConvertFunction(originalFunction: Value, resultType: Type, withoutActuallyEscaping: Bool) -> ConvertFunctionInst {
let convertFunction = bridged.createConvertFunction(originalFunction.bridged, resultType.bridged, withoutActuallyEscaping)
return notifyNew(convertFunction.getAs(ConvertFunctionInst.self))
Expand Down
6 changes: 4 additions & 2 deletions SwiftCompilerSources/Sources/SIL/Instruction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,8 @@ final public class AllocExistentialBoxInst : SingleValueInstruction, Allocation
/// `end_borrow`).
public protocol ScopedInstruction {
var endOperands: LazyFilterSequence<UseList> { get }

var endInstructions: EndInstructions { get }
}

extension Instruction {
Expand Down Expand Up @@ -1330,7 +1332,7 @@ final public class BeginApplyInst : MultipleValueInstruction, FullApplySite {
}
}

final public class EndApplyInst : Instruction, UnaryInstruction {
final public class EndApplyInst : SingleValueInstruction, UnaryInstruction {
public var token: MultipleValueInstructionResult { operand.value as! MultipleValueInstructionResult }
public var beginApply: BeginApplyInst { token.parentInstruction as! BeginApplyInst }
}
Expand All @@ -1342,7 +1344,7 @@ final public class AbortApplyInst : Instruction, UnaryInstruction {

extension BeginApplyInst : ScopedInstruction {
public var endOperands: LazyFilterSequence<UseList> {
return token.uses.lazy.filter { _ in true }
return token.uses.lazy.filter { $0.endsLifetime }
}
}

Expand Down
85 changes: 66 additions & 19 deletions SwiftCompilerSources/Sources/SIL/Utilities/AccessUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,33 @@ public enum AccessBase : CustomStringConvertible, Hashable {
}
case let sb as StoreBorrowInst:
self = .storeBorrow(sb)
case let p2a as PointerToAddressInst:
if let global = p2a.resultOfGlobalAddressorCall {
self = .global(global)
} else {
self = .pointer(p2a)
}
default:
self = .unidentified
}
}

/// Return 'nil' for global varabiables and unidentified addresses.
public var address: Value? {
switch self {
case .global, .unidentified: return nil
case .box(let pbi): return pbi
case .stack(let asi): return asi
case .class(let rea): return rea
case .tail(let rta): return rta
case .argument(let arg): return arg
case .yield(let result): return result
case .storeBorrow(let sb): return sb
case .pointer(let p): return p
case .index(let ia): return ia
}
}

public var description: String {
switch self {
case .unidentified: return "?"
Expand Down Expand Up @@ -480,9 +502,40 @@ public enum EnclosingScope {
case base(AccessBase)
}

private struct EnclosingAccessWalker : AddressUseDefWalker {
var enclosingScope: EnclosingScope?

mutating func walk(startAt address: Value, initialPath: UnusedWalkingPath = UnusedWalkingPath()) {
if walkUp(address: address, path: UnusedWalkingPath()) == .abortWalk {
assert(enclosingScope == nil, "shouldn't have set an enclosing scope in an aborted walk")
}
}

mutating func rootDef(address: Value, path: UnusedWalkingPath) -> WalkResult {
assert(enclosingScope == nil, "rootDef should only called once")
// Try identifying the address a pointer originates from
if let p2ai = address as? PointerToAddressInst, let originatingAddr = p2ai.originatingAddress {
return walkUp(address: originatingAddr, path: path)
}
enclosingScope = .base(AccessBase(baseAddress: address))
return .continueWalk
}

mutating func walkUp(address: Value, path: UnusedWalkingPath) -> WalkResult {
if let ba = address as? BeginAccessInst {
enclosingScope = .scope(ba)
return .continueWalk
}
return walkUpDefault(address: address, path: path)
}
}

private struct AccessPathWalker : AddressUseDefWalker {
var result = AccessPath.unidentified()
var foundBeginAccess: BeginAccessInst?

// List of nested BeginAccessInst: inside-out order.
var foundBeginAccesses = SingleInlineArray<BeginAccessInst>()

let enforceConstantProjectionPath: Bool

init(enforceConstantProjectionPath: Bool = false) {
Expand Down Expand Up @@ -528,18 +581,9 @@ private struct AccessPathWalker : AddressUseDefWalker {
mutating func rootDef(address: Value, path: Path) -> WalkResult {
assert(result.base == .unidentified, "rootDef should only called once")
// Try identifying the address a pointer originates from
if let p2ai = address as? PointerToAddressInst {
if let originatingAddr = p2ai.originatingAddress {
return walkUp(address: originatingAddr, path: path)
} else if let global = p2ai.resultOfGlobalAddressorCall {
self.result = AccessPath(base: .global(global), projectionPath: path.projectionPath)
return .continueWalk
} else {
self.result = AccessPath(base: .pointer(p2ai), projectionPath: path.projectionPath)
return .continueWalk
}
if let p2ai = address as? PointerToAddressInst, let originatingAddr = p2ai.originatingAddress {
return walkUp(address: originatingAddr, path: path)
}

let base = AccessBase(baseAddress: address)
self.result = AccessPath(base: base, projectionPath: path.projectionPath)
return .continueWalk
Expand All @@ -557,8 +601,8 @@ private struct AccessPathWalker : AddressUseDefWalker {
// An `index_addr` instruction cannot be derived from an address
// projection. Bail out
return .abortWalk
} else if let ba = address as? BeginAccessInst, foundBeginAccess == nil {
foundBeginAccess = ba
} else if let ba = address as? BeginAccessInst {
foundBeginAccesses.push(ba)
}
return walkUpDefault(address: address, path: path.with(indexAddr: false))
}
Expand Down Expand Up @@ -611,17 +655,20 @@ extension Value {
public var accessPathWithScope: (AccessPath, scope: BeginAccessInst?) {
var walker = AccessPathWalker()
walker.walk(startAt: self)
return (walker.result, walker.foundBeginAccess)
return (walker.result, walker.foundBeginAccesses.first)
}

/// Computes the enclosing access scope of this address value.
public var enclosingAccessScope: EnclosingScope {
var walker = EnclosingAccessWalker()
walker.walk(startAt: self)
return walker.enclosingScope ?? .base(.unidentified)
}

public var accessBaseWithScopes: (AccessBase, SingleInlineArray<BeginAccessInst>) {
var walker = AccessPathWalker()
walker.walk(startAt: self)
if let ba = walker.foundBeginAccess {
return .scope(ba)
}
return .base(walker.result.base)
return (walker.result.base, walker.foundBeginAccesses)
}

/// The root definition of a reference, obtained by skipping ownership forwarding and ownership transition.
Expand Down
4 changes: 4 additions & 0 deletions include/swift/SIL/SILBridging.h
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,10 @@ struct BridgedBuilder{

SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createEndAccess(BridgedValue value) const;

SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createEndApply(BridgedValue value) const;

SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createAbortApply(BridgedValue value) const;

SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createConvertFunction(BridgedValue originalFunction, BridgedType resultType, bool withoutActuallyEscaping) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createConvertEscapeToNoEscape(BridgedValue originalFunction, BridgedType resultType, bool isLifetimeGuaranteed) const;

Expand Down
10 changes: 10 additions & 0 deletions include/swift/SIL/SILBridgingImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,16 @@ BridgedInstruction BridgedBuilder::createEndAccess(BridgedValue value) const {
return {unbridged().createEndAccess(regularLoc(), value.getSILValue(), false)};
}

BridgedInstruction BridgedBuilder::createEndApply(BridgedValue value) const {
swift::ASTContext &ctxt = unbridged().getASTContext();
return {unbridged().createEndApply(regularLoc(), value.getSILValue(),
swift::SILType::getEmptyTupleType(ctxt))};
}

BridgedInstruction BridgedBuilder::createAbortApply(BridgedValue value) const {
return {unbridged().createAbortApply(regularLoc(), value.getSILValue())};
}

BridgedInstruction BridgedBuilder::createConvertFunction(BridgedValue originalFunction, BridgedType resultType, bool withoutActuallyEscaping) const {
return {unbridged().createConvertFunction(regularLoc(), originalFunction.getSILValue(), resultType.unbridged(), withoutActuallyEscaping)};
}
Expand Down
33 changes: 33 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3217,6 +3217,13 @@ class EndApplyInst;
class AbortApplyInst;
class EndBorrowInst;

struct EndApplyFilter {
std::optional<Operand*> operator()(Operand *use) const;
};

using EndApplyRange = OptionalTransformRange<ValueBase::use_range,
EndApplyFilter>;

/// BeginApplyInst - Represents the beginning of the full application of
/// a yield_once coroutine (up until the coroutine yields a value back).
class BeginApplyInst final
Expand Down Expand Up @@ -3267,6 +3274,8 @@ class BeginApplyInst final
&getAllResultsBuffer().drop_back(isCalleeAllocated() ? 1 : 0).back());
}

EndApplyRange getEndApplyUses() const;

MultipleValueInstructionResult *getCalleeAllocationResult() const {
if (!isCalleeAllocated()) {
return nullptr;
Expand Down Expand Up @@ -3336,6 +3345,25 @@ class EndApplyInst
}
};

inline std::optional<Operand*>
EndApplyFilter::operator()(Operand *use) const {
// An end_borrow ends the coroutine scope at a dead-end block without
// terminating the coroutine.
switch (use->getUser()->getKind()) {
case SILInstructionKind::EndApplyInst:
case SILInstructionKind::AbortApplyInst:
case SILInstructionKind::EndBorrowInst:
return use;
default:
return std::nullopt;
}
}

inline EndApplyRange BeginApplyInst::getEndApplyUses() const {
return makeOptionalTransformRange(
getTokenResult()->getUses(), EndApplyFilter());
}

//===----------------------------------------------------------------------===//
// Literal instructions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -8802,6 +8830,11 @@ class MarkDependenceInst
uint8_t(MarkDependenceKind::NonEscaping);
}

void settleToEscaping() {
sharedUInt8().MarkDependenceInst.dependenceKind =
uint8_t(MarkDependenceKind::Escaping);
}

/// Visit the instructions that end the lifetime of an OSSA on-stack closure.
bool visitNonEscapingLifetimeEnds(
llvm::function_ref<bool (Operand*)> visitScopeEnd,
Expand Down
14 changes: 7 additions & 7 deletions lib/SIL/IR/SILInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,8 @@ void BeginApplyInst::getCoroutineEndPoints(
SmallVectorImpl<EndApplyInst *> &endApplyInsts,
SmallVectorImpl<AbortApplyInst *> &abortApplyInsts,
SmallVectorImpl<EndBorrowInst *> *endBorrowInsts) const {
for (auto *tokenUse : getTokenResult()->getUses()) {
auto *user = tokenUse->getUser();
for (auto *use : getEndApplyUses()) {
auto *user = use->getUser();
if (auto *end = dyn_cast<EndApplyInst>(user)) {
endApplyInsts.push_back(end);
continue;
Expand All @@ -733,19 +733,19 @@ void BeginApplyInst::getCoroutineEndPoints(
SmallVectorImpl<Operand *> &endApplyInsts,
SmallVectorImpl<Operand *> &abortApplyInsts,
SmallVectorImpl<Operand *> *endBorrowInsts) const {
for (auto *tokenUse : getTokenResult()->getUses()) {
auto *user = tokenUse->getUser();
for (auto *use : getEndApplyUses()) {
auto *user = use->getUser();
if (isa<EndApplyInst>(user)) {
endApplyInsts.push_back(tokenUse);
endApplyInsts.push_back(use);
continue;
}
if (isa<AbortApplyInst>(user)) {
abortApplyInsts.push_back(tokenUse);
abortApplyInsts.push_back(use);
continue;
}

assert(isa<EndBorrowInst>(user));
abortApplyInsts.push_back(tokenUse);
abortApplyInsts.push_back(use);
}
}

Expand Down
3 changes: 1 addition & 2 deletions lib/SIL/Utils/OwnershipUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,7 @@ bool BorrowingOperand::visitScopeEndingUses(
}
case BorrowingOperandKind::BeginApply: {
bool deadApply = true;
auto *user = cast<BeginApplyInst>(op->getUser());
for (auto *use : user->getTokenResult()->getUses()) {
for (auto *use : cast<BeginApplyInst>(op->getUser())->getEndApplyUses()) {
deadApply = false;
if (!visitScopeEnd(use))
return false;
Expand Down
3 changes: 1 addition & 2 deletions lib/SIL/Verifier/SILOwnershipVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,7 @@ bool SILValueOwnershipChecker::checkYieldWithoutLifetimeEndingUses(
// If we have a guaranteed value, make sure that all uses are before our
// end_yield.
SmallVector<Operand *, 4> coroutineEndUses;
for (auto *use : yield->getParent<BeginApplyInst>()->
getTokenResult()->getUses()) {
for (auto *use : yield->getParent<BeginApplyInst>()->getEndApplyUses()) {
coroutineEndUses.push_back(use);
}

Expand Down
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Mandatory/AddressLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2682,8 +2682,8 @@ void ApplyRewriter::convertBeginApplyWithOpaqueYield() {
SILValue load =
resultBuilder.emitLoadBorrowOperation(callLoc, &newResult);
oldResult.replaceAllUsesWith(load);
for (auto *user : origCall->getTokenResult()->getUsers()) {
pass.getBuilder(user->getIterator())
for (auto *use : origCall->getEndApplyUses()) {
pass.getBuilder(use->getUser()->getIterator())
.createEndBorrow(pass.genLoc(), load);
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ static bool visitScopeEndsRequiringInit(
// Check for yields from a modify coroutine.
if (auto bai =
dyn_cast_or_null<BeginApplyInst>(operand->getDefiningInstruction())) {
for (auto *inst : bai->getTokenResult()->getUsers()) {
for (auto *use : bai->getEndApplyUses()) {
auto *inst = use->getUser();
assert(isa<EndApplyInst>(inst) || isa<AbortApplyInst>(inst) ||
isa<EndBorrowInst>(inst));
visit(inst, ScopeRequiringFinalInit::Coroutine);
Expand Down
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Transforms/TempRValueElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ collectLoads(Operand *addressUse, CopyAddrInst *originalCopy,
// Register 'end_apply'/'abort_apply' as loads as well
// 'checkNoSourceModification' should check instructions until
// 'end_apply'/'abort_apply'.
for (auto tokenUse : beginApply->getTokenResult()->getUses()) {
SILInstruction *tokenUser = tokenUse->getUser();
for (auto *tokenUse : beginApply->getEndApplyUses()) {
auto *tokenUser = tokenUse->getUser();
if (tokenUser->getParent() != block)
return false;
loadInsts.insert(tokenUser);
Expand Down
11 changes: 4 additions & 7 deletions lib/SILOptimizer/Utils/Devirtualize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,14 +643,11 @@ replaceBeginApplyInst(SILBuilder &builder, SILPassManager *pm, SILLocation loc,
if (newArgBorrows.empty())
return {newBAI, changedCFG};

SILValue token = newBAI->getTokenResult();

// The token will only be used by end_apply and abort_apply. Use that to
// insert the end_borrows we need /after/ those uses.
for (auto *use : token->getUses()) {
// Insert the end_borrows after end_apply and abort_apply users.
for (auto *use : newBAI->getEndApplyUses()) {
SILBuilderWithScope borrowBuilder(
&*std::next(use->getUser()->getIterator()),
builder.getBuilderContext());
&*std::next(use->getUser()->getIterator()),
builder.getBuilderContext());
for (SILValue borrow : newArgBorrows) {
borrowBuilder.createEndBorrow(loc, borrow);
}
Expand Down
Loading