Skip to content

Commit 5578fde

Browse files
[Clang] Introduce [[clang::coro_await_elidable]]
1 parent cece4ba commit 5578fde

25 files changed

+335
-110
lines changed

clang/docs/ReleaseNotes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ Attribute Changes in Clang
136136
- The ``hybrid_patchable`` attribute is now supported on ARM64EC targets. It can be used to specify
137137
that a function requires an additional x86-64 thunk, which may be patched at runtime.
138138

139+
- Introduced a new attribute ``[[clang::coro_await_elidable]]`` on coroutine return types
140+
to express elideability at call sites where the coroutine is co_awaited as a prvalue.
141+
139142
Improvements to Clang's diagnostics
140143
-----------------------------------
141144

clang/include/clang/AST/Expr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2991,6 +2991,9 @@ class CallExpr : public Expr {
29912991

29922992
bool hasStoredFPFeatures() const { return CallExprBits.HasFPFeatures; }
29932993

2994+
bool isCoroMustElide() const { return CallExprBits.IsCoroMustElide; }
2995+
void setCoroMustElide(bool V = true) { CallExprBits.IsCoroMustElide = V; }
2996+
29942997
Decl *getCalleeDecl() { return getCallee()->getReferencedDeclOfCallee(); }
29952998
const Decl *getCalleeDecl() const {
29962999
return getCallee()->getReferencedDeclOfCallee();

clang/include/clang/AST/Stmt.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,11 @@ class alignas(void *) Stmt {
561561
LLVM_PREFERRED_TYPE(bool)
562562
unsigned HasFPFeatures : 1;
563563

564+
/// True if the call expression is a must-elide call to a coroutine.
565+
unsigned IsCoroMustElide : 1;
566+
564567
/// Padding used to align OffsetToTrailingObjects to a byte multiple.
565-
unsigned : 24 - 3 - NumExprBits;
568+
unsigned : 24 - 4 - NumExprBits;
566569

567570
/// The offset in bytes from the this pointer to the start of the
568571
/// trailing objects belonging to CallExpr. Intentionally byte sized

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,14 @@ def CoroDisableLifetimeBound : InheritableAttr {
12201220
let SimpleHandler = 1;
12211221
}
12221222

1223+
def CoroAwaitElidable : InheritableAttr {
1224+
let Spellings = [Clang<"coro_await_elidable">];
1225+
let Subjects = SubjectList<[CXXRecord]>;
1226+
let LangOpts = [CPlusPlus];
1227+
let Documentation = [CoroAwaitElidableDoc];
1228+
let SimpleHandler = 1;
1229+
}
1230+
12231231
// OSObject-based attributes.
12241232
def OSConsumed : InheritableParamAttr {
12251233
let Spellings = [Clang<"os_consumed">];

clang/include/clang/Basic/AttrDocs.td

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8147,6 +8147,38 @@ but do not pass them to the underlying coroutine or pass them by value.
81478147
}];
81488148
}
81498149

8150+
def CoroAwaitElidableDoc : Documentation {
8151+
let Category = DocCatDecl;
8152+
let Content = [{
8153+
The ``[[clang::coro_await_elidable]]`` is a class attribute which can be applied
8154+
to a coroutine return type.
8155+
8156+
When a coroutine function that returns such a type calls another coroutine function,
8157+
the compiler performs heap allocation elision when the call to the coroutine function
8158+
is immediately co_awaited as a prvalue. In this case, the coroutine frame for the
8159+
callee will be a local variable within the enclosing braces in the caller's stack
8160+
frame. And the local variable, like other variables in coroutines, may be collected
8161+
into the coroutine frame, which may be allocated on the heap.
8162+
8163+
Example:
8164+
8165+
.. code-block:: c++
8166+
8167+
class [[clang::coro_await_elidable]] Task { ... };
8168+
8169+
Task foo();
8170+
Task bar() {
8171+
co_await foo(); // foo()'s coroutine frame on this line is elidable
8172+
auto t = foo(); // foo()'s coroutine frame on this line is NOT elidable
8173+
co_await t;
8174+
}
8175+
8176+
The behavior is undefined if the caller coroutine is destroyed earlier than the
8177+
callee coroutine.
8178+
8179+
}];
8180+
}
8181+
81508182
def CountedByDocs : Documentation {
81518183
let Category = DocCatField;
81528184
let Content = [{
@@ -8306,4 +8338,3 @@ Declares that a function potentially allocates heap memory, and prevents any pot
83068338
of ``nonallocating`` by the compiler.
83078339
}];
83088340
}
8309-

clang/lib/AST/Expr.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,7 @@ CallExpr::CallExpr(StmtClass SC, Expr *Fn, ArrayRef<Expr *> PreArgs,
14741474
this->computeDependence();
14751475

14761476
CallExprBits.HasFPFeatures = FPFeatures.requiresTrailingStorage();
1477+
CallExprBits.IsCoroMustElide = false;
14771478
if (hasStoredFPFeatures())
14781479
setStoredFPFeatures(FPFeatures);
14791480
}
@@ -1489,6 +1490,7 @@ CallExpr::CallExpr(StmtClass SC, unsigned NumPreArgs, unsigned NumArgs,
14891490
assert((CallExprBits.OffsetToTrailingObjects == OffsetToTrailingObjects) &&
14901491
"OffsetToTrailingObjects overflow!");
14911492
CallExprBits.HasFPFeatures = HasFPFeatures;
1493+
CallExprBits.IsCoroMustElide = false;
14921494
}
14931495

14941496
CallExpr *CallExpr::Create(const ASTContext &Ctx, Expr *Fn,

clang/lib/CodeGen/CGBlocks.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() {
11631163
}
11641164

11651165
RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
1166-
ReturnValueSlot ReturnValue) {
1166+
ReturnValueSlot ReturnValue,
1167+
llvm::CallBase **CallOrInvoke) {
11671168
const auto *BPT = E->getCallee()->getType()->castAs<BlockPointerType>();
11681169
llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee());
11691170
llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType();
@@ -1220,7 +1221,7 @@ RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
12201221
CGCallee Callee(CGCalleeInfo(), Func);
12211222

12221223
// And call the block.
1223-
return EmitCall(FnInfo, Callee, ReturnValue, Args);
1224+
return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke);
12241225
}
12251226

12261227
Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) {

clang/lib/CodeGen/CGCUDARuntime.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {}
2525

2626
RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
2727
const CUDAKernelCallExpr *E,
28-
ReturnValueSlot ReturnValue) {
28+
ReturnValueSlot ReturnValue,
29+
llvm::CallBase **CallOrInvoke) {
2930
llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok");
3031
llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end");
3132

@@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
3536

3637
eval.begin(CGF);
3738
CGF.EmitBlock(ConfigOKBlock);
38-
CGF.EmitSimpleCallExpr(E, ReturnValue);
39+
CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke);
3940
CGF.EmitBranch(ContBlock);
4041

4142
CGF.EmitBlock(ContBlock);

clang/lib/CodeGen/CGCUDARuntime.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/IR/GlobalValue.h"
2222

2323
namespace llvm {
24+
class CallBase;
2425
class Function;
2526
class GlobalVariable;
2627
}
@@ -82,9 +83,10 @@ class CGCUDARuntime {
8283
CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {}
8384
virtual ~CGCUDARuntime();
8485

85-
virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
86-
const CUDAKernelCallExpr *E,
87-
ReturnValueSlot ReturnValue);
86+
virtual RValue
87+
EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E,
88+
ReturnValueSlot ReturnValue,
89+
llvm::CallBase **CallOrInvoke = nullptr);
8890

8991
/// Emits a kernel launch stub.
9092
virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0;

clang/lib/CodeGen/CGCXXABI.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -485,11 +485,11 @@ class CGCXXABI {
485485
llvm::PointerUnion<const CXXDeleteExpr *, const CXXMemberCallExpr *>;
486486

487487
/// Emit the ABI-specific virtual destructor call.
488-
virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
489-
const CXXDestructorDecl *Dtor,
490-
CXXDtorType DtorType,
491-
Address This,
492-
DeleteOrMemberCallExpr E) = 0;
488+
virtual llvm::Value *
489+
EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
490+
CXXDtorType DtorType, Address This,
491+
DeleteOrMemberCallExpr E,
492+
llvm::CallBase **CallOrInvoke) = 0;
493493

494494
virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF,
495495
GlobalDecl GD,

clang/lib/CodeGen/CGClass.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,15 +2192,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF,
21922192
return true;
21932193
}
21942194

2195-
void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
2196-
CXXCtorType Type,
2197-
bool ForVirtualBase,
2198-
bool Delegating,
2199-
Address This,
2200-
CallArgList &Args,
2201-
AggValueSlot::Overlap_t Overlap,
2202-
SourceLocation Loc,
2203-
bool NewPointerIsChecked) {
2195+
void CodeGenFunction::EmitCXXConstructorCall(
2196+
const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase,
2197+
bool Delegating, Address This, CallArgList &Args,
2198+
AggValueSlot::Overlap_t Overlap, SourceLocation Loc,
2199+
bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) {
22042200
const CXXRecordDecl *ClassDecl = D->getParent();
22052201

22062202
if (!NewPointerIsChecked)
@@ -2248,7 +2244,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
22482244
const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall(
22492245
Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs);
22502246
CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type));
2251-
EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc);
2247+
EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc);
22522248

22532249
// Generate vtable assumptions if we're constructing a complete object
22542250
// with a vtable. We don't do this for base subobjects for two reasons:

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "clang/Basic/SourceManager.h"
3434
#include "llvm/ADT/Hashing.h"
3535
#include "llvm/ADT/STLExtras.h"
36+
#include "llvm/ADT/ScopeExit.h"
3637
#include "llvm/ADT/StringExtras.h"
3738
#include "llvm/IR/DataLayout.h"
3839
#include "llvm/IR/Intrinsics.h"
@@ -5444,24 +5445,38 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV,
54445445
//===--------------------------------------------------------------------===//
54455446

54465447
RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
5447-
ReturnValueSlot ReturnValue) {
5448+
ReturnValueSlot ReturnValue,
5449+
llvm::CallBase **CallOrInvoke) {
5450+
llvm::CallBase *CallOrInvokeStorage;
5451+
if (!CallOrInvoke) {
5452+
CallOrInvoke = &CallOrInvokeStorage;
5453+
}
5454+
5455+
auto AddCoroMustElideOnExit = llvm::make_scope_exit([&] {
5456+
if (E->isCoroMustElide()) {
5457+
auto *I = *CallOrInvoke;
5458+
if (I)
5459+
I->addFnAttr(llvm::Attribute::CoroMustElide);
5460+
}
5461+
});
5462+
54485463
// Builtins never have block type.
54495464
if (E->getCallee()->getType()->isBlockPointerType())
5450-
return EmitBlockCallExpr(E, ReturnValue);
5465+
return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke);
54515466

54525467
if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
5453-
return EmitCXXMemberCallExpr(CE, ReturnValue);
5468+
return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke);
54545469

54555470
if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E))
5456-
return EmitCUDAKernelCallExpr(CE, ReturnValue);
5471+
return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke);
54575472

54585473
// A CXXOperatorCallExpr is created even for explicit object methods, but
54595474
// these should be treated like static function call.
54605475
if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(E))
54615476
if (const auto *MD =
54625477
dyn_cast_if_present<CXXMethodDecl>(CE->getCalleeDecl());
54635478
MD && MD->isImplicitObjectMemberFunction())
5464-
return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue);
5479+
return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke);
54655480

54665481
CGCallee callee = EmitCallee(E->getCallee());
54675482

@@ -5474,14 +5489,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
54745489
return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr());
54755490
}
54765491

5477-
return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue);
5492+
return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue,
5493+
/*Chain=*/nullptr, CallOrInvoke);
54785494
}
54795495

54805496
/// Emit a CallExpr without considering whether it might be a subclass.
54815497
RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E,
5482-
ReturnValueSlot ReturnValue) {
5498+
ReturnValueSlot ReturnValue,
5499+
llvm::CallBase **CallOrInvoke) {
54835500
CGCallee Callee = EmitCallee(E->getCallee());
5484-
return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue);
5501+
return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue,
5502+
/*Chain=*/nullptr, CallOrInvoke);
54855503
}
54865504

54875505
// Detect the unusual situation where an inline version is shadowed by a
@@ -5685,8 +5703,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) {
56855703
llvm_unreachable("bad evaluation kind");
56865704
}
56875705

5688-
LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) {
5689-
RValue RV = EmitCallExpr(E);
5706+
LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E,
5707+
llvm::CallBase **CallOrInvoke) {
5708+
RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke);
56905709

56915710
if (!RV.isScalar())
56925711
return MakeAddrLValue(RV.getAggregateAddress(), E->getType(),
@@ -5809,9 +5828,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) {
58095828
AlignmentSource::Decl);
58105829
}
58115830

5812-
RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee,
5813-
const CallExpr *E, ReturnValueSlot ReturnValue,
5814-
llvm::Value *Chain) {
5831+
RValue CodeGenFunction::EmitCall(QualType CalleeType,
5832+
const CGCallee &OrigCallee, const CallExpr *E,
5833+
ReturnValueSlot ReturnValue,
5834+
llvm::Value *Chain,
5835+
llvm::CallBase **CallOrInvoke) {
58155836
// Get the actual function type. The callee type will always be a pointer to
58165837
// function type or a block pointer type.
58175838
assert(CalleeType->isFunctionPointerType() &&
@@ -6031,8 +6052,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
60316052
Address(Handle, Handle->getType(), CGM.getPointerAlign()));
60326053
Callee.setFunctionPointer(Stub);
60336054
}
6034-
llvm::CallBase *CallOrInvoke = nullptr;
6035-
RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke,
6055+
llvm::CallBase *LocalCallOrInvoke = nullptr;
6056+
RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke,
60366057
E == MustTailCall, E->getExprLoc());
60376058

60386059
// Generate function declaration DISuprogram in order to be used
@@ -6041,11 +6062,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
60416062
if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) {
60426063
FunctionArgList Args;
60436064
QualType ResTy = BuildFunctionArgList(CalleeDecl, Args);
6044-
DI->EmitFuncDeclForCallSite(CallOrInvoke,
6065+
DI->EmitFuncDeclForCallSite(LocalCallOrInvoke,
60456066
DI->getFunctionType(CalleeDecl, ResTy, Args),
60466067
CalleeDecl);
60476068
}
60486069
}
6070+
if (CallOrInvoke)
6071+
*CallOrInvoke = LocalCallOrInvoke;
60496072

60506073
return Call;
60516074
}

0 commit comments

Comments
 (0)