Skip to content

Commit a346bfd

Browse files
committed
[CodeGen][ARM64EC] Add support for hybrid_patchable attribute.
1 parent fe5d791 commit a346bfd

File tree

10 files changed

+396
-10
lines changed

10 files changed

+396
-10
lines changed

llvm/include/llvm/Bitcode/LLVMBitCodes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,7 @@ enum AttributeKindCodes {
753753
ATTR_KIND_CORO_ONLY_DESTROY_WHEN_COMPLETE = 90,
754754
ATTR_KIND_DEAD_ON_UNWIND = 91,
755755
ATTR_KIND_RANGE = 92,
756+
ATTR_KIND_HYBRID_PATCHABLE = 93,
756757
};
757758

758759
enum ComdatSelectionKindCodes {

llvm/include/llvm/CodeGen/AsmPrinter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -900,14 +900,14 @@ class AsmPrinter : public MachineFunctionPass {
900900
virtual void emitModuleCommandLines(Module &M);
901901

902902
GCMetadataPrinter *getOrCreateGCPrinter(GCStrategy &S);
903-
virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
904903
void emitGlobalIFunc(Module &M, const GlobalIFunc &GI);
905904

906905
private:
907906
/// This method decides whether the specified basic block requires a label.
908907
bool shouldEmitLabelForBasicBlock(const MachineBasicBlock &MBB) const;
909908

910909
protected:
910+
virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
911911
virtual bool shouldEmitWeakSwiftAsyncExtendedFramePointerFlags() const {
912912
return false;
913913
}

llvm/include/llvm/IR/Attributes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def ElementType : TypeAttr<"elementtype", [ParamAttr]>;
109109
/// symbol.
110110
def FnRetThunkExtern : EnumAttr<"fn_ret_thunk_extern", [FnAttr]>;
111111

112+
/// Function has a hybrid patchable thunk.
113+
def HybridPatchable : EnumAttr<"hybrid_patchable", [FnAttr]>;
114+
112115
/// Pass structure in an alloca.
113116
def InAlloca : TypeAttr<"inalloca", [ParamAttr]>;
114117

llvm/lib/Bitcode/Writer/BitcodeWriter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
717717
return bitc::ATTR_KIND_HOT;
718718
case Attribute::ElementType:
719719
return bitc::ATTR_KIND_ELEMENTTYPE;
720+
case Attribute::HybridPatchable:
721+
return bitc::ATTR_KIND_HYBRID_PATCHABLE;
720722
case Attribute::InlineHint:
721723
return bitc::ATTR_KIND_INLINE_HINT;
722724
case Attribute::InReg:

llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2899,7 +2899,7 @@ bool AsmPrinter::emitSpecialLLVMGlobal(const GlobalVariable *GV) {
28992899
for (auto &U : Arr->operands()) {
29002900
auto *C = cast<Constant>(U);
29012901
auto *Src = cast<Function>(C->getOperand(0)->stripPointerCasts());
2902-
auto *Dst = cast<Function>(C->getOperand(1)->stripPointerCasts());
2902+
auto *Dst = cast<GlobalValue>(C->getOperand(1)->stripPointerCasts());
29032903
int Kind = cast<ConstantInt>(C->getOperand(2))->getZExtValue();
29042904

29052905
if (Src->hasDLLImportStorageClass()) {

llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp

Lines changed: 124 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/SmallVector.h"
2222
#include "llvm/ADT/Statistic.h"
2323
#include "llvm/IR/CallingConv.h"
24+
#include "llvm/IR/GlobalAlias.h"
2425
#include "llvm/IR/IRBuilder.h"
2526
#include "llvm/IR/Instruction.h"
2627
#include "llvm/IR/Mangler.h"
@@ -29,6 +30,7 @@
2930
#include "llvm/Pass.h"
3031
#include "llvm/Support/CommandLine.h"
3132
#include "llvm/TargetParser/Triple.h"
33+
#include <map>
3234

3335
using namespace llvm;
3436
using namespace llvm::COFF;
@@ -57,15 +59,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
5759
Function *buildEntryThunk(Function *F);
5860
void lowerCall(CallBase *CB);
5961
Function *buildGuestExitThunk(Function *F);
60-
bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
62+
Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
63+
GlobalAlias *MangledAlias);
64+
bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
65+
std::map<GlobalAlias *, GlobalAlias *> &PathcableFns);
6166
bool runOnModule(Module &M) override;
6267

6368
private:
6469
int cfguard_module_flag = 0;
6570
FunctionType *GuardFnType = nullptr;
6671
PointerType *GuardFnPtrType = nullptr;
72+
FunctionType *DispatchFnType = nullptr;
73+
PointerType *DispatchFnPtrType = nullptr;
6774
Constant *GuardFnCFGlobal = nullptr;
6875
Constant *GuardFnGlobal = nullptr;
76+
Constant *DispatchFnGlobal = nullptr;
6977
Module *M = nullptr;
7078

7179
Type *PtrTy;
@@ -615,6 +623,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
615623
return GuestExit;
616624
}
617625

626+
// Builds the "$hybpatch_thunk" function for a hybrid patchable function.
//
// The thunk loads the dispatch function pointer from DispatchFnGlobal, calls
// it with (UnmangledAlias, the exit thunk for F's signature, the current
// aliasee of UnmangledAlias) and then musttail-calls whatever pointer the
// dispatcher returned, forwarding the thunk's own arguments.
//
// \param UnmangledAlias alias passed to the dispatcher as its first argument;
//        its aliasee is passed as the third argument.
// \param MangledAlias alias whose aliasee (a Function) supplies the signature
//        and attributes; on return it is re-pointed at the new thunk.
// \returns the newly created thunk function.
Function *
627+
AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
628+
GlobalAlias *MangledAlias) {
629+
// getThunkType requires a name stream, but only Arm64Ty is used below, so
// the generated name is written to a null stream.
llvm::raw_null_ostream NullThunkName;
630+
FunctionType *Arm64Ty, *X64Ty;
631+
Function *F = cast<Function>(MangledAlias->getAliasee());
632+
getThunkType(F->getFunctionType(), F->getAttributes(),
633+
Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
634+
// Derive the thunk name from the mangled alias name: for names that start
// with '?' and contain '@' the suffix is inserted before the first '@',
// otherwise it is appended.
// NOTE(review): the '?'...'@' form is presumably an MSVC C++ mangled name --
// confirm.
std::string ThunkName(MangledAlias->getName());
635+
if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
636+
ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
637+
} else {
638+
ThunkName.append("$hybpatch_thunk");
639+
}
640+
641+
// The thunk is weak_odr, lives in a comdat named after itself and is placed
// in the ".wowthk$aa" section.
Function *GuestExit =
642+
Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
643+
GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
644+
GuestExit->setSection(".wowthk$aa");
645+
BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
646+
IRBuilder<> B(BB);
647+
648+
// Load the pointer to the dispatch function from its global symbol.
649+
LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
650+
651+
// Create new dispatch call instruction.
652+
Function *ExitThunk =
653+
buildExitThunk(F->getFunctionType(), F->getAttributes());
654+
CallInst *Dispatch = B.CreateCall(
655+
DispatchFnType, DispatchLoad,
656+
{B.CreateBitCast(UnmangledAlias, B.getPtrTy()),
657+
B.CreateBitCast(ExitThunk, B.getPtrTy()),
658+
B.CreateBitCast(UnmangledAlias->getAliasee(), B.getPtrTy())});
659+
660+
// Ensure that the first arguments are passed in the correct registers.
661+
Dispatch->setCallingConv(CallingConv::CFGuard_Check);
662+
663+
// The dispatcher's return value is the call target; forward every thunk
// argument to it unchanged in a musttail call.
Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
664+
SmallVector<Value *> Args;
665+
for (Argument &Arg : GuestExit->args())
666+
Args.push_back(&Arg);
667+
CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
668+
Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
669+
670+
if (Call->getType()->isVoidTy())
671+
B.CreateRetVoid();
672+
else
673+
B.CreateRet(Call);
674+
675+
// Copy F's sret attribute on parameter 0 to both the thunk and the tail
// call, unless that parameter is also marked inreg.
auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
676+
auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
677+
if (SRetAttr.isValid() && !InRegAttr.isValid()) {
678+
GuestExit->addParamAttr(0, SRetAttr);
679+
Call->addParamAttr(0, SRetAttr);
680+
}
681+
682+
// From here on the mangled alias resolves to the thunk rather than to F.
MangledAlias->setAliasee(GuestExit);
683+
return GuestExit;
684+
}
685+
618686
// Lower an indirect call with inline code.
619687
void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
620688
assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
@@ -670,17 +738,48 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
670738

671739
GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
672740
GuardFnPtrType = PointerType::get(GuardFnType, 0);
741+
DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
742+
DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
673743
GuardFnCFGlobal =
674744
M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
675745
GuardFnGlobal =
676746
M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
747+
DispatchFnGlobal =
748+
M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
749+
750+
std::map<GlobalAlias *, GlobalAlias *> PatchableFns;
751+
for (Function &F : Mod) {
752+
if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
753+
F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))
754+
continue;
755+
756+
// Rename hybrid patchable functions and change callers to use a global
757+
// alias instead.
758+
if (std::optional<std::string> MangledName =
759+
getArm64ECMangledFunctionName(F.getName().str())) {
760+
std::string OrigName(F.getName());
761+
F.setName(MangledName.value() + "$hp_target");
677762

678-
SetVector<Function *> DirectCalledFns;
763+
auto *A =
764+
GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
765+
F.replaceAllUsesWith(A);
766+
F.setMetadata("arm64ec_exp_name",
767+
MDNode::get(M->getContext(),
768+
MDString::get(M->getContext(),
769+
"EXP+" + MangledName.value())));
770+
A->setAliasee(&F);
771+
772+
PatchableFns[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
773+
MangledName.value(), &F);
774+
}
775+
}
776+
777+
SetVector<GlobalValue *> DirectCalledFns;
679778
for (Function &F : Mod)
680779
if (!F.isDeclaration() &&
681780
F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
682781
F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
683-
processFunction(F, DirectCalledFns);
782+
processFunction(F, DirectCalledFns, PatchableFns);
684783

685784
struct ThunkInfo {
686785
Constant *Src;
@@ -698,14 +797,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
698797
{&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
699798
}
700799
}
701-
for (Function *F : DirectCalledFns) {
800+
for (GlobalValue *O : DirectCalledFns) {
801+
auto GA = dyn_cast<GlobalAlias>(O);
802+
auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
702803
ThunkMapping.push_back(
703-
{F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
804+
{O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
704805
Arm64ECThunkType::Exit});
705-
if (!F->hasDLLImportStorageClass())
806+
if (!GA && !F->hasDLLImportStorageClass())
706807
ThunkMapping.push_back(
707808
{buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
708809
}
810+
for (auto I : PatchableFns) {
811+
Function *Thunk = buildPatchableThunk(I.first, I.second);
812+
ThunkMapping.push_back({Thunk, I.first, Arm64ECThunkType::GuestExit});
813+
}
709814

710815
if (!ThunkMapping.empty()) {
711816
SmallVector<Constant *> ThunkMappingArrayElems;
@@ -728,7 +833,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
728833
}
729834

730835
bool AArch64Arm64ECCallLowering::processFunction(
731-
Function &F, SetVector<Function *> &DirectCalledFns) {
836+
Function &F, SetVector<GlobalValue *> &DirectCalledFns,
837+
std::map<GlobalAlias *, GlobalAlias *> &PatchableFns) {
732838
SmallVector<CallBase *, 8> IndirectCalls;
733839

734840
// For ARM64EC targets, a function definition's name is mangled differently
@@ -780,6 +886,17 @@ bool AArch64Arm64ECCallLowering::processFunction(
780886
continue;
781887
}
782888

889+
// Use mangled global alias for direct calls to patchable functions.
890+
if (GlobalAlias *A =
891+
dyn_cast_or_null<GlobalAlias>(CB->getCalledOperand())) {
892+
auto I = PatchableFns.find(A);
893+
if (I != PatchableFns.end()) {
894+
CB->setCalledOperand(I->second);
895+
DirectCalledFns.insert(I->first);
896+
continue;
897+
}
898+
}
899+
783900
IndirectCalls.push_back(CB);
784901
++Arm64ECCallsLowered;
785902
}

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class AArch64AsmPrinter : public AsmPrinter {
185185
void PrintDebugValueComment(const MachineInstr *MI, raw_ostream &OS);
186186

187187
void emitFunctionBodyEnd() override;
188+
void emitGlobalAlias(const Module &M, const GlobalAlias &GA) override;
188189

189190
MCSymbol *GetCPISymbol(unsigned CPID) const override;
190191
void emitEndOfAsmFile(Module &M) override;
@@ -1202,6 +1203,32 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
12021203
}
12031204
}
12041205

1206+
// Emits a global alias, with special handling for aliases whose aliasee is a
// function carrying "arm64ec_exp_name" metadata: such an alias is emitted as a
// weak external COFF function symbol assigned to the symbol named by the
// metadata string. All other aliases go through AsmPrinter::emitGlobalAlias.
void AArch64AsmPrinter::emitGlobalAlias(const Module &M,
1207+
const GlobalAlias &GA) {
1208+
if (auto F = dyn_cast_or_null<Function>(GA.getAliasee())) {
1209+
// Global aliases must point to a definition, but unmangled patchable
1210+
// symbols are special and need to point to an undefined symbol with an "EXP+"
1211+
// prefix. Such an undefined symbol is resolved by the linker by creating
1212+
// an x86 thunk that jumps back to the actual EC target.
1213+
if (MDNode *Node = F->getMetadata("arm64ec_exp_name")) {
1214+
// Operand 0 of the metadata node holds the target symbol name.
StringRef ExpStr = cast<MDString>(Node->getOperand(0))->getString();
1215+
MCSymbol *ExpSym = MMI->getContext().getOrCreateSymbol(ExpStr);
1216+
MCSymbol *Sym = MMI->getContext().getOrCreateSymbol(GA.getName());
1217+
// Describe the alias as an external COFF symbol of function type.
OutStreamer->beginCOFFSymbolDef(Sym);
1218+
OutStreamer->emitCOFFSymbolStorageClass(COFF::IMAGE_SYM_CLASS_EXTERNAL);
1219+
OutStreamer->emitCOFFSymbolType(COFF::IMAGE_SYM_DTYPE_FUNCTION
1220+
<< COFF::SCT_COMPLEX_TYPE_SHIFT);
1221+
OutStreamer->endCOFFSymbolDef();
1222+
// Mark the alias weak and define it as a reference to the target symbol.
OutStreamer->emitSymbolAttribute(Sym, MCSA_Weak);
1223+
OutStreamer->emitAssignment(
1224+
Sym, MCSymbolRefExpr::create(ExpSym, MCSymbolRefExpr::VK_None,
1225+
MMI->getContext()));
1226+
return;
1227+
}
1228+
}
1229+
// No "arm64ec_exp_name" metadata: use the generic alias emission.
AsmPrinter::emitGlobalAlias(M, GA);
1230+
}
1231+
12051232
/// Small jump tables contain an unsigned byte or half, representing the offset
12061233
/// from the lowest-addressed possible destination to the desired basic
12071234
/// block. Since all instructions are 4-byte aligned, this is further compressed

llvm/lib/Target/AArch64/AArch64CallingConvention.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def CC_AArch64_Win64_CFGuard_Check : CallingConv<[
333333

334334
let Entry = 1 in
335335
def CC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[
336-
CCIfType<[i64], CCAssignToReg<[X11, X10]>>
336+
CCIfType<[i64], CCAssignToReg<[X11, X10, X9]>>
337337
]>;
338338

339339
let Entry = 1 in

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
932932
case Attribute::DisableSanitizerInstrumentation:
933933
case Attribute::FnRetThunkExtern:
934934
case Attribute::Hot:
935+
case Attribute::HybridPatchable:
935936
case Attribute::NoRecurse:
936937
case Attribute::InlineHint:
937938
case Attribute::MinSize:

0 commit comments

Comments
 (0)