Skip to content

Commit 5b18641

Browse files
[LLVM][Coroutines] Transform "coro_elide_safe" calls to switch ABI coroutines to the noalloc variant
1 parent e2a6027 commit 5b18641

File tree

11 files changed

+281
-2
lines changed

11 files changed

+281
-2
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- CoroAnnotationElide.h - Elide attributed safe coroutine calls ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file
10+
// This pass transforms all Call or Invoke instructions that are annotated
11+
// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12+
// The frame of the callee coroutine is allocated inside the caller. A pointer
13+
// to the allocated frame will be passed into the `.noalloc` ramp function.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
18+
#define LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
19+
20+
#include "llvm/Analysis/CGSCCPassManager.h"
21+
#include "llvm/Analysis/LazyCallGraph.h"
22+
#include "llvm/IR/PassManager.h"
23+
24+
namespace llvm {
25+
26+
/// CGSCC pass that redirects Call/Invoke instructions annotated
/// "coro_elide_safe" to the `.noalloc` variant of the called coroutine.
struct CoroAnnotationElidePass : PassInfoMixin<CoroAnnotationElidePass> {
  CoroAnnotationElidePass() = default;

  /// Reports this pass as not required.
  static bool isRequired() { return false; }

  /// Entry point invoked by the CGSCC pass manager for the SCC \p C.
  PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                        LazyCallGraph &CG, CGSCCUpdateResult &UR);
};
34+
} // end namespace llvm
35+
36+
#endif // LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
#include "llvm/Target/TargetMachine.h"
139139
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
140140
#include "llvm/Transforms/CFGuard.h"
141+
#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
141142
#include "llvm/Transforms/Coroutines/CoroCleanup.h"
142143
#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
143144
#include "llvm/Transforms/Coroutines/CoroEarly.h"

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/Support/VirtualFileSystem.h"
3434
#include "llvm/Target/TargetMachine.h"
3535
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
36+
#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
3637
#include "llvm/Transforms/Coroutines/CoroCleanup.h"
3738
#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
3839
#include "llvm/Transforms/Coroutines/CoroEarly.h"
@@ -984,8 +985,10 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
984985
MainCGPipeline.addPass(createCGSCCToFunctionPassAdaptor(
985986
RequireAnalysisPass<ShouldNotRunFunctionPassesAnalysis, Function>()));
986987

987-
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
988+
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
988989
MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0));
990+
MainCGPipeline.addPass(CoroAnnotationElidePass());
991+
}
989992

990993
// Make sure we don't affect potential future NoRerun CGSCC adaptors.
991994
MIWP.addLateModulePass(createModuleToFunctionPassAdaptor(
@@ -1027,9 +1030,12 @@ PassBuilder::buildModuleInlinerPipeline(OptimizationLevel Level,
10271030
buildFunctionSimplificationPipeline(Level, Phase),
10281031
PTO.EagerlyInvalidateAnalyses));
10291032

1030-
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
1033+
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
10311034
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
10321035
CoroSplitPass(Level != OptimizationLevel::O0)));
1036+
MPM.addPass(
1037+
createModuleToPostOrderCGSCCPassAdaptor(CoroAnnotationElidePass()));
1038+
}
10331039

10341040
return MPM;
10351041
}

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ CGSCC_PASS("attributor-light-cgscc", AttributorLightCGSCCPass())
243243
CGSCC_PASS("invalidate<all>", InvalidateAllAnalysesPass())
244244
CGSCC_PASS("no-op-cgscc", NoOpCGSCCPass())
245245
CGSCC_PASS("openmp-opt-cgscc", OpenMPOptCGSCCPass())
246+
CGSCC_PASS("coro-annotation-elide", CoroAnnotationElidePass())
246247
#undef CGSCC_PASS
247248

248249
#ifndef CGSCC_PASS_WITH_PARAMS

llvm/lib/Transforms/Coroutines/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_llvm_component_library(LLVMCoroutines
22
Coroutines.cpp
3+
CoroAnnotationElide.cpp
34
CoroCleanup.cpp
45
CoroConditionalWrapper.cpp
56
CoroEarly.cpp
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file
10+
// This pass transforms all Call or Invoke instructions that are annotated
11+
// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12+
// The frame of the callee coroutine is allocated inside the caller. A pointer
13+
// to the allocated frame will be passed into the `.noalloc` ramp function.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18+
19+
#include "llvm/Analysis/CGSCCPassManager.h"
20+
#include "llvm/Analysis/LazyCallGraph.h"
21+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
22+
#include "llvm/IR/Analysis.h"
23+
#include "llvm/IR/IRBuilder.h"
24+
#include "llvm/IR/InstIterator.h"
25+
#include "llvm/IR/Instruction.h"
26+
#include "llvm/IR/Module.h"
27+
#include "llvm/IR/PassManager.h"
28+
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
29+
30+
#include <cassert>
31+
32+
using namespace llvm;
33+
34+
#define DEBUG_TYPE "coro-annotation-elide"
35+
36+
// Walk the entry block of F from the top and hand back the first instruction
// that is not an alloca. Reaching the end of the block without finding one
// is treated as unreachable.
static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
  BasicBlock &Entry = F->getEntryBlock();
  for (auto It = Entry.begin(), End = Entry.end(); It != End; ++It) {
    if (isa<AllocaInst>(&*It))
      continue;
    return &*It;
  }
  llvm_unreachable("no terminator in the entry block");
}
42+
43+
// Create an alloca in the caller, using FrameSize and FrameAlign as the callee
44+
// coroutine's activation frame.
45+
static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
46+
Align FrameAlign) {
47+
LLVMContext &C = Caller->getContext();
48+
BasicBlock::iterator InsertPt =
49+
getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
50+
const DataLayout &DL = Caller->getDataLayout();
51+
auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
52+
auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
53+
Frame->setAlignment(FrameAlign);
54+
return Frame;
55+
}
56+
57+
// Given a call or invoke instruction to the elide safe coroutine, this function
58+
// does the following:
59+
// - Allocate a frame for the callee coroutine in the caller using alloca.
60+
// - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
61+
// pointer to the frame as an additional argument to NewCallee.
62+
static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
63+
uint64_t FrameSize, Align FrameAlign) {
64+
// TODO: generate the lifetime intrinsics for the new frame. This will require
65+
// introduction of two pesudo lifetime intrinsics in the frontend around the
66+
// `co_await` expression and convert them to real lifetime intrinsics here.
67+
auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
68+
auto NewCBInsertPt = CB->getIterator();
69+
llvm::CallBase *NewCB = nullptr;
70+
SmallVector<Value *, 4> NewArgs;
71+
NewArgs.append(CB->arg_begin(), CB->arg_end());
72+
NewArgs.push_back(FramePtr);
73+
74+
if (auto *CI = dyn_cast<CallInst>(CB)) {
75+
auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
76+
NewArgs, "", NewCBInsertPt);
77+
NewCI->setTailCallKind(CI->getTailCallKind());
78+
NewCB = NewCI;
79+
} else if (auto *II = dyn_cast<InvokeInst>(CB)) {
80+
NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
81+
II->getNormalDest(), II->getUnwindDest(),
82+
NewArgs, std::nullopt, "", NewCBInsertPt);
83+
} else {
84+
llvm_unreachable("CallBase should either be Call or Invoke!");
85+
}
86+
87+
NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
88+
NewCB->setCallingConv(CB->getCallingConv());
89+
NewCB->setAttributes(CB->getAttributes());
90+
NewCB->setDebugLoc(CB->getDebugLoc());
91+
std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
92+
NewCB->bundle_op_info_begin());
93+
94+
NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
95+
CB->replaceAllUsesWith(NewCB);
96+
CB->eraseFromParent();
97+
}
98+
99+
// For every function in the SCC that has a `<name>.noalloc` sibling in the
// same module, rewrite each direct call to it that
//   - lives in a presplit coroutine, and
//   - carries the coro_elide_safe call-site attribute
// into a call to the `.noalloc` variant (see processCall). The frame size and
// alignment are read from the `dereferenceable` and `align` attributes on the
// last parameter of the `.noalloc` function.
//
// Returns PreservedAnalyses::none() if any call was rewritten, otherwise
// PreservedAnalyses::all().
PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
                                               CGSCCAnalysisManager &AM,
                                               LazyCallGraph &CG,
                                               CGSCCUpdateResult &UR) {
  bool Changed = false;

  auto &FAM =
      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

  for (LazyCallGraph::Node &N : C) {
    Function *Callee = &N.getFunction();
    Function *NewCallee = Callee->getParent()->getFunction(
        (Callee->getName() + ".noalloc").str());
    // Without at least one parameter there is no frame-pointer slot, and
    // `arg_size() - 1` below would wrap around.
    if (!NewCallee || NewCallee->arg_empty())
      continue;

    auto FramePtrArgPosition = NewCallee->arg_size() - 1;
    auto FrameSize =
        NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
    auto FrameAlign =
        NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
    // A missing `dereferenceable` attribute reads back as 0; eliding into a
    // zero-byte alloca would hand the callee a frame it cannot use.
    if (FrameSize == 0)
      continue;

    // Snapshot the direct call sites first: processCall erases instructions,
    // which would invalidate iteration over Callee->users().
    SmallVector<CallBase *, 4> Users;
    for (auto *U : Callee->users()) {
      if (auto *CB = dyn_cast<CallBase>(U)) {
        if (CB->getCalledFunction() == Callee)
          Users.push_back(CB);
      }
    }

    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);

    for (auto *CB : Users) {
      auto *Caller = CB->getFunction();
      if (!Caller || !Caller->isPresplitCoroutine() ||
          !CB->hasFnAttr(llvm::Attribute::CoroElideSafe))
        continue;

      // Both lookups may come back null when the lazy call graph has not
      // visited Caller yet. In that case the call is still rewritten but the
      // call graph update is skipped.
      auto *CallerN = CG.lookup(*Caller);
      auto *CallerC = CallerN ? CG.lookupSCC(*CallerN) : nullptr;
      processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);

      ORE.emit([&]() {
        return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
               << "'" << ore::NV("callee", Callee->getName())
               << "' elided in '" << ore::NV("caller", Caller->getName())
               << "'";
      });
      Changed = true;
      if (CallerC)
        updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR,
                                               FAM);
    }
  }
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

llvm/test/Other/new-pm-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@
226226
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
227227
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
228228
; CHECK-O-NEXT: Running pass: CoroSplitPass
229+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
229230
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
230231
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
231232
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis

llvm/test/Other/new-pm-thinlto-postlink-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
154154
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
155155
; CHECK-O-NEXT: Running pass: CoroSplitPass
156+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
156157
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
157158
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
158159
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis

llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
138138
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
139139
; CHECK-O-NEXT: Running pass: CoroSplitPass
140+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
140141
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
141142
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
142143
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis

llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
146146
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
147147
; CHECK-O-NEXT: Running pass: CoroSplitPass
148+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
148149
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
149150
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
150151
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; Testing elide performed its job for calls to coroutines marked safe.
2+
; RUN: opt < %s -S -passes='cgscc(coro-annotation-elide)' | FileCheck %s
3+
4+
%struct.Task = type { ptr }
5+
6+
declare void @print(i32) nounwind
7+
8+
; Resume part of the coroutine: a stub whose only action is @print(i32 0).
; Listed first in @callee.resumers.
9+
define fastcc void @callee.resume(ptr dereferenceable(1)) {
10+
tail call void @print(i32 0)
11+
ret void
12+
}
13+
14+
; Destroy part of the coroutine: a stub whose only action is @print(i32 1).
; Listed second in @callee.resumers.
15+
define fastcc void @callee.destroy(ptr) {
16+
tail call void @print(i32 1)
17+
ret void
18+
}
19+
20+
; Cleanup part of the coroutine: a stub whose only action is @print(i32 2).
; Listed third in @callee.resumers.
21+
define fastcc void @callee.cleanup(ptr) {
22+
tail call void @print(i32 2)
23+
ret void
24+
}
25+
26+
; Table of the resume, destroy and cleanup functions, in that order. It is
; passed as the last operand of @llvm.coro.id in @callee and @callee.noalloc.
@callee.resumers = internal constant [3 x ptr] [
27+
ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup]
28+
29+
declare void @alloc(i1) nounwind
30+
31+
; Regular ramp function of the coroutine. It takes only the user argument and
; calls @llvm.coro.alloc; the call to it in @caller is the one the pass is
; expected to redirect to @callee.noalloc.
; CHECK-LABEL: define ptr @callee
32+
define ptr @callee(i8 %arg) {
33+
entry:
34+
%task = alloca %struct.Task, align 8
35+
%id = call token @llvm.coro.id(i32 0, ptr null,
36+
ptr @callee,
37+
ptr @callee.resumers)
38+
%alloc = call i1 @llvm.coro.alloc(token %id)
39+
%hdl = call ptr @llvm.coro.begin(token %id, ptr null)
40+
store ptr %hdl, ptr %task
41+
ret ptr %task
42+
}
43+
44+
; `.noalloc` variant of @callee: same body without the @llvm.coro.alloc call,
; plus a trailing %frame pointer parameter. The dereferenceable(32) and
; align(8) attributes on %frame are what the test relies on for the size and
; alignment of the alloca expected in @caller ([32 x i8], align 8).
; CHECK-LABEL: define ptr @callee.noalloc
45+
define ptr @callee.noalloc(i8 %arg, ptr dereferenceable(32) align(8) %frame) {
46+
entry:
47+
%task = alloca %struct.Task, align 8
48+
%id = call token @llvm.coro.id(i32 0, ptr null,
49+
ptr @callee,
50+
ptr @callee.resumers)
51+
%hdl = call ptr @llvm.coro.begin(token %id, ptr null)
52+
store ptr %hdl, ptr %task
53+
ret ptr %task
54+
}
55+
56+
; The function under test. Attribute group #0 marks it presplitcoroutine and
; group #1 puts coro_elide_safe on the call to @callee (both groups are defined
; at the end of the file). The expected output allocates a 32-byte, 8-aligned
; frame and passes it as the extra last argument to @callee.noalloc.
; CHECK-LABEL: define ptr @caller()
57+
; Function Attrs: presplitcoroutine
58+
define ptr @caller() #0 {
59+
entry:
60+
%task = call ptr @callee(i8 0) #1
61+
ret ptr %task
62+
63+
; CHECK: %[[FRAME:.+]] = alloca [32 x i8], align 8
64+
; CHECK-NEXT: %[[TASK:.+]] = call ptr @callee.noalloc(i8 0, ptr %[[FRAME]])
65+
; CHECK-NEXT: ret ptr %[[TASK]]
66+
}
67+
68+
declare token @llvm.coro.id(i32, ptr, ptr, ptr)
69+
declare ptr @llvm.coro.begin(token, ptr)
70+
declare ptr @llvm.coro.frame()
71+
declare ptr @llvm.coro.subfn.addr(ptr, i8)
72+
declare i1 @llvm.coro.alloc(token)
73+
74+
attributes #0 = { presplitcoroutine }
75+
attributes #1 = { coro_elide_safe }

0 commit comments

Comments
 (0)