Skip to content

Commit 2445632

Browse files
committed
[CodeGen][ARM64EC] Add support for hybrid_patchable attribute.
1 parent efab4a3 commit 2445632

File tree

10 files changed

+484
-11
lines changed

10 files changed

+484
-11
lines changed

llvm/include/llvm/Bitcode/LLVMBitCodes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,7 @@ enum AttributeKindCodes {
756756
ATTR_KIND_RANGE = 92,
757757
ATTR_KIND_SANITIZE_NUMERICAL_STABILITY = 93,
758758
ATTR_KIND_INITIALIZES = 94,
759+
ATTR_KIND_HYBRID_PATCHABLE = 95,
759760
};
760761

761762
enum ComdatSelectionKindCodes {

llvm/include/llvm/CodeGen/AsmPrinter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,14 +905,14 @@ class AsmPrinter : public MachineFunctionPass {
905905
virtual void emitModuleCommandLines(Module &M);
906906

907907
GCMetadataPrinter *getOrCreateGCPrinter(GCStrategy &S);
908-
virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
909908
void emitGlobalIFunc(Module &M, const GlobalIFunc &GI);
910909

911910
private:
912911
/// This method decides whether the specified basic block requires a label.
913912
bool shouldEmitLabelForBasicBlock(const MachineBasicBlock &MBB) const;
914913

915914
protected:
915+
virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
916916
virtual bool shouldEmitWeakSwiftAsyncExtendedFramePointerFlags() const {
917917
return false;
918918
}

llvm/include/llvm/IR/Attributes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ def ElementType : TypeAttr<"elementtype", [ParamAttr]>;
112112
/// symbol.
113113
def FnRetThunkExtern : EnumAttr<"fn_ret_thunk_extern", [FnAttr]>;
114114

115+
/// Function has a hybrid patchable thunk.
116+
def HybridPatchable : EnumAttr<"hybrid_patchable", [FnAttr]>;
117+
115118
/// Pass structure in an alloca.
116119
def InAlloca : TypeAttr<"inalloca", [ParamAttr]>;
117120

llvm/lib/Bitcode/Writer/BitcodeWriter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
717717
return bitc::ATTR_KIND_HOT;
718718
case Attribute::ElementType:
719719
return bitc::ATTR_KIND_ELEMENTTYPE;
720+
case Attribute::HybridPatchable:
721+
return bitc::ATTR_KIND_HYBRID_PATCHABLE;
720722
case Attribute::InlineHint:
721723
return bitc::ATTR_KIND_INLINE_HINT;
722724
case Attribute::InReg:

llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2902,8 +2902,8 @@ bool AsmPrinter::emitSpecialLLVMGlobal(const GlobalVariable *GV) {
29022902
auto *Arr = cast<ConstantArray>(GV->getInitializer());
29032903
for (auto &U : Arr->operands()) {
29042904
auto *C = cast<Constant>(U);
2905-
auto *Src = cast<Function>(C->getOperand(0)->stripPointerCasts());
2906-
auto *Dst = cast<Function>(C->getOperand(1)->stripPointerCasts());
2905+
auto *Src = cast<GlobalValue>(C->getOperand(0)->stripPointerCasts());
2906+
auto *Dst = cast<GlobalValue>(C->getOperand(1)->stripPointerCasts());
29072907
int Kind = cast<ConstantInt>(C->getOperand(2))->getZExtValue();
29082908

29092909
if (Src->hasDLLImportStorageClass()) {

llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp

Lines changed: 131 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/SmallVector.h"
2222
#include "llvm/ADT/Statistic.h"
2323
#include "llvm/IR/CallingConv.h"
24+
#include "llvm/IR/GlobalAlias.h"
2425
#include "llvm/IR/IRBuilder.h"
2526
#include "llvm/IR/Instruction.h"
2627
#include "llvm/IR/Mangler.h"
@@ -69,15 +70,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
6970
Function *buildEntryThunk(Function *F);
7071
void lowerCall(CallBase *CB);
7172
Function *buildGuestExitThunk(Function *F);
72-
bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
73+
Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
74+
GlobalAlias *MangledAlias);
75+
bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
76+
DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
7377
bool runOnModule(Module &M) override;
7478

7579
private:
7680
int cfguard_module_flag = 0;
7781
FunctionType *GuardFnType = nullptr;
7882
PointerType *GuardFnPtrType = nullptr;
83+
FunctionType *DispatchFnType = nullptr;
84+
PointerType *DispatchFnPtrType = nullptr;
7985
Constant *GuardFnCFGlobal = nullptr;
8086
Constant *GuardFnGlobal = nullptr;
87+
Constant *DispatchFnGlobal = nullptr;
8188
Module *M = nullptr;
8289

8390
Type *PtrTy;
@@ -671,6 +678,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
671678
return GuestExit;
672679
}
673680

681+
Function *
682+
AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
683+
GlobalAlias *MangledAlias) {
684+
llvm::raw_null_ostream NullThunkName;
685+
FunctionType *Arm64Ty, *X64Ty;
686+
Function *F = cast<Function>(MangledAlias->getAliasee());
687+
SmallVector<ThunkArgTranslation> ArgTranslations;
688+
getThunkType(F->getFunctionType(), F->getAttributes(),
689+
Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
690+
ArgTranslations);
691+
std::string ThunkName(MangledAlias->getName());
692+
if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
693+
ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
694+
} else {
695+
ThunkName.append("$hybpatch_thunk");
696+
}
697+
698+
Function *GuestExit =
699+
Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
700+
GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
701+
GuestExit->setSection(".wowthk$aa");
702+
BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
703+
IRBuilder<> B(BB);
704+
705+
// Load the global symbol as a pointer to the check function.
706+
LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
707+
708+
// Create new dispatch call instruction.
709+
Function *ExitThunk =
710+
buildExitThunk(F->getFunctionType(), F->getAttributes());
711+
CallInst *Dispatch =
712+
B.CreateCall(DispatchFnType, DispatchLoad,
713+
{UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});
714+
715+
// Ensure that the first arguments are passed in the correct registers.
716+
Dispatch->setCallingConv(CallingConv::CFGuard_Check);
717+
718+
Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
719+
SmallVector<Value *> Args;
720+
for (Argument &Arg : GuestExit->args())
721+
Args.push_back(&Arg);
722+
CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
723+
Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
724+
725+
if (Call->getType()->isVoidTy())
726+
B.CreateRetVoid();
727+
else
728+
B.CreateRet(Call);
729+
730+
auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
731+
auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
732+
if (SRetAttr.isValid() && !InRegAttr.isValid()) {
733+
GuestExit->addParamAttr(0, SRetAttr);
734+
Call->addParamAttr(0, SRetAttr);
735+
}
736+
737+
MangledAlias->setAliasee(GuestExit);
738+
return GuestExit;
739+
}
740+
674741
// Lower an indirect call with inline code.
675742
void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
676743
assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
@@ -726,17 +793,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
726793

727794
GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
728795
GuardFnPtrType = PointerType::get(GuardFnType, 0);
796+
DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
797+
DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
729798
GuardFnCFGlobal =
730799
M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
731800
GuardFnGlobal =
732801
M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
802+
DispatchFnGlobal =
803+
M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
804+
805+
DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
806+
SetVector<GlobalAlias *> PatchableFns;
733807

734-
SetVector<Function *> DirectCalledFns;
808+
for (Function &F : Mod) {
809+
if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
810+
F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))
811+
continue;
812+
813+
// Rename hybrid patchable functions and change callers to use a global
814+
// alias instead.
815+
if (std::optional<std::string> MangledName =
816+
getArm64ECMangledFunctionName(F.getName().str())) {
817+
std::string OrigName(F.getName());
818+
F.setName(MangledName.value() + "$hp_target");
819+
820+
// The unmangled symbol is a weak alias to an undefined symbol with the
821+
// "EXP+" prefix. This undefined symbol is resolved by the linker by
822+
// creating an x86 thunk that jumps back to the actual EC target. Since we
823+
// can't represent that in IR, we create an alias to the target instead.
824+
// The "EXP+" symbol is set as metadata, which is then used by
825+
// emitGlobalAlias to emit the right alias.
826+
auto *A =
827+
GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
828+
F.replaceAllUsesWith(A);
829+
F.setMetadata("arm64ec_exp_name",
830+
MDNode::get(M->getContext(),
831+
MDString::get(M->getContext(),
832+
"EXP+" + MangledName.value())));
833+
A->setAliasee(&F);
834+
835+
FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
836+
MangledName.value(), &F);
837+
PatchableFns.insert(A);
838+
}
839+
}
840+
841+
SetVector<GlobalValue *> DirectCalledFns;
735842
for (Function &F : Mod)
736843
if (!F.isDeclaration() &&
737844
F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
738845
F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
739-
processFunction(F, DirectCalledFns);
846+
processFunction(F, DirectCalledFns, FnsMap);
740847

741848
struct ThunkInfo {
742849
Constant *Src;
@@ -754,14 +861,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
754861
{&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
755862
}
756863
}
757-
for (Function *F : DirectCalledFns) {
864+
for (GlobalValue *O : DirectCalledFns) {
865+
auto GA = dyn_cast<GlobalAlias>(O);
866+
auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
758867
ThunkMapping.push_back(
759-
{F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
868+
{O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
760869
Arm64ECThunkType::Exit});
761-
if (!F->hasDLLImportStorageClass())
870+
if (!GA && !F->hasDLLImportStorageClass())
762871
ThunkMapping.push_back(
763872
{buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
764873
}
874+
for (GlobalAlias *A : PatchableFns) {
875+
Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
876+
ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
877+
}
765878

766879
if (!ThunkMapping.empty()) {
767880
SmallVector<Constant *> ThunkMappingArrayElems;
@@ -784,7 +897,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
784897
}
785898

786899
bool AArch64Arm64ECCallLowering::processFunction(
787-
Function &F, SetVector<Function *> &DirectCalledFns) {
900+
Function &F, SetVector<GlobalValue *> &DirectCalledFns,
901+
DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
788902
SmallVector<CallBase *, 8> IndirectCalls;
789903

790904
// For ARM64EC targets, a function definition's name is mangled differently
@@ -836,6 +950,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
836950
continue;
837951
}
838952

953+
// Use mangled global alias for direct calls to patchable functions.
954+
if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
955+
auto I = FnsMap.find(A);
956+
if (I != FnsMap.end()) {
957+
CB->setCalledOperand(I->second);
958+
DirectCalledFns.insert(I->first);
959+
continue;
960+
}
961+
}
962+
839963
IndirectCalls.push_back(CB);
840964
++Arm64ECCallsLowered;
841965
}

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ class AArch64AsmPrinter : public AsmPrinter {
193193
void PrintDebugValueComment(const MachineInstr *MI, raw_ostream &OS);
194194

195195
void emitFunctionBodyEnd() override;
196+
void emitGlobalAlias(const Module &M, const GlobalAlias &GA) override;
196197

197198
MCSymbol *GetCPISymbol(unsigned CPID) const override;
198199
void emitEndOfAsmFile(Module &M) override;
@@ -1210,6 +1211,32 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
12101211
}
12111212
}
12121213

1214+
void AArch64AsmPrinter::emitGlobalAlias(const Module &M,
1215+
const GlobalAlias &GA) {
1216+
if (auto F = dyn_cast_or_null<Function>(GA.getAliasee())) {
1217+
// Global aliases must point to a definition, but unmangled patchable
1218+
// symbols are special and need to point to an undefined symbol with "EXP+"
1219+
// prefix. Such undefined symbol is resolved by the linker by creating
1220+
// x86 thunk that jumps back to the actual EC target.
1221+
if (MDNode *Node = F->getMetadata("arm64ec_exp_name")) {
1222+
StringRef ExpStr = cast<MDString>(Node->getOperand(0))->getString();
1223+
MCSymbol *ExpSym = MMI->getContext().getOrCreateSymbol(ExpStr);
1224+
MCSymbol *Sym = MMI->getContext().getOrCreateSymbol(GA.getName());
1225+
OutStreamer->beginCOFFSymbolDef(Sym);
1226+
OutStreamer->emitCOFFSymbolStorageClass(COFF::IMAGE_SYM_CLASS_EXTERNAL);
1227+
OutStreamer->emitCOFFSymbolType(COFF::IMAGE_SYM_DTYPE_FUNCTION
1228+
<< COFF::SCT_COMPLEX_TYPE_SHIFT);
1229+
OutStreamer->endCOFFSymbolDef();
1230+
OutStreamer->emitSymbolAttribute(Sym, MCSA_Weak);
1231+
OutStreamer->emitAssignment(
1232+
Sym, MCSymbolRefExpr::create(ExpSym, MCSymbolRefExpr::VK_None,
1233+
MMI->getContext()));
1234+
return;
1235+
}
1236+
}
1237+
AsmPrinter::emitGlobalAlias(M, GA);
1238+
}
1239+
12131240
/// Small jump tables contain an unsigned byte or half, representing the offset
12141241
/// from the lowest-addressed possible destination to the desired basic
12151242
/// block. Since all instructions are 4-byte aligned, this is further compressed

llvm/lib/Target/AArch64/AArch64CallingConvention.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def CC_AArch64_Win64_CFGuard_Check : CallingConv<[
333333

334334
let Entry = 1 in
335335
def CC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[
336-
CCIfType<[i64], CCAssignToReg<[X11, X10]>>
336+
CCIfType<[i64], CCAssignToReg<[X11, X10, X9]>>
337337
]>;
338338

339339
let Entry = 1 in

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
932932
case Attribute::DisableSanitizerInstrumentation:
933933
case Attribute::FnRetThunkExtern:
934934
case Attribute::Hot:
935+
case Attribute::HybridPatchable:
935936
case Attribute::NoRecurse:
936937
case Attribute::InlineHint:
937938
case Attribute::MinSize:

0 commit comments

Comments
 (0)