21
21
#include " llvm/ADT/SmallVector.h"
22
22
#include " llvm/ADT/Statistic.h"
23
23
#include " llvm/IR/CallingConv.h"
24
+ #include " llvm/IR/GlobalAlias.h"
24
25
#include " llvm/IR/IRBuilder.h"
25
26
#include " llvm/IR/Instruction.h"
26
27
#include " llvm/IR/Mangler.h"
29
30
#include " llvm/Pass.h"
30
31
#include " llvm/Support/CommandLine.h"
31
32
#include " llvm/TargetParser/Triple.h"
33
+ #include < map>
32
34
33
35
using namespace llvm ;
34
36
using namespace llvm ::COFF;
@@ -57,15 +59,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
57
59
Function *buildEntryThunk (Function *F);
58
60
void lowerCall (CallBase *CB);
59
61
Function *buildGuestExitThunk (Function *F);
60
- bool processFunction (Function &F, SetVector<Function *> &DirectCalledFns);
62
+ Function *buildPatchableThunk (GlobalAlias *UnmangledAlias,
63
+ GlobalAlias *MangledAlias);
64
+ bool processFunction (Function &F, SetVector<GlobalValue *> &DirectCalledFns,
65
+ std::map<GlobalAlias *, GlobalAlias *> &PathcableFns);
61
66
bool runOnModule (Module &M) override ;
62
67
63
68
private:
64
69
int cfguard_module_flag = 0 ;
65
70
FunctionType *GuardFnType = nullptr ;
66
71
PointerType *GuardFnPtrType = nullptr ;
72
+ FunctionType *DispatchFnType = nullptr ;
73
+ PointerType *DispatchFnPtrType = nullptr ;
67
74
Constant *GuardFnCFGlobal = nullptr ;
68
75
Constant *GuardFnGlobal = nullptr ;
76
+ Constant *DispatchFnGlobal = nullptr ;
69
77
Module *M = nullptr ;
70
78
71
79
Type *PtrTy;
@@ -615,6 +623,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
615
623
return GuestExit;
616
624
}
617
625
626
+ Function *
627
+ AArch64Arm64ECCallLowering::buildPatchableThunk (GlobalAlias *UnmangledAlias,
628
+ GlobalAlias *MangledAlias) {
629
+ llvm::raw_null_ostream NullThunkName;
630
+ FunctionType *Arm64Ty, *X64Ty;
631
+ Function *F = cast<Function>(MangledAlias->getAliasee ());
632
+ getThunkType (F->getFunctionType (), F->getAttributes (),
633
+ Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
634
+ std::string ThunkName (MangledAlias->getName ());
635
+ if (ThunkName[0 ] == ' ?' && ThunkName.find (" @" ) != std::string::npos) {
636
+ ThunkName.insert (ThunkName.find (" @" ), " $hybpatch_thunk" );
637
+ } else {
638
+ ThunkName.append (" $hybpatch_thunk" );
639
+ }
640
+
641
+ Function *GuestExit =
642
+ Function::Create (Arm64Ty, GlobalValue::WeakODRLinkage, 0 , ThunkName, M);
643
+ GuestExit->setComdat (M->getOrInsertComdat (ThunkName));
644
+ GuestExit->setSection (" .wowthk$aa" );
645
+ BasicBlock *BB = BasicBlock::Create (M->getContext (), " " , GuestExit);
646
+ IRBuilder<> B (BB);
647
+
648
+ // Load the global symbol as a pointer to the check function.
649
+ LoadInst *DispatchLoad = B.CreateLoad (DispatchFnPtrType, DispatchFnGlobal);
650
+
651
+ // Create new dispatch call instruction.
652
+ Function *ExitThunk =
653
+ buildExitThunk (F->getFunctionType (), F->getAttributes ());
654
+ CallInst *Dispatch = B.CreateCall (
655
+ DispatchFnType, DispatchLoad,
656
+ {B.CreateBitCast (UnmangledAlias, B.getPtrTy ()),
657
+ B.CreateBitCast (ExitThunk, B.getPtrTy ()),
658
+ B.CreateBitCast (UnmangledAlias->getAliasee (), B.getPtrTy ())});
659
+
660
+ // Ensure that the first arguments are passed in the correct registers.
661
+ Dispatch->setCallingConv (CallingConv::CFGuard_Check);
662
+
663
+ Value *DispatchRetVal = B.CreateBitCast (Dispatch, PtrTy);
664
+ SmallVector<Value *> Args;
665
+ for (Argument &Arg : GuestExit->args ())
666
+ Args.push_back (&Arg);
667
+ CallInst *Call = B.CreateCall (Arm64Ty, DispatchRetVal, Args);
668
+ Call->setTailCallKind (llvm::CallInst::TCK_MustTail);
669
+
670
+ if (Call->getType ()->isVoidTy ())
671
+ B.CreateRetVoid ();
672
+ else
673
+ B.CreateRet (Call);
674
+
675
+ auto SRetAttr = F->getAttributes ().getParamAttr (0 , Attribute::StructRet);
676
+ auto InRegAttr = F->getAttributes ().getParamAttr (0 , Attribute::InReg);
677
+ if (SRetAttr.isValid () && !InRegAttr.isValid ()) {
678
+ GuestExit->addParamAttr (0 , SRetAttr);
679
+ Call->addParamAttr (0 , SRetAttr);
680
+ }
681
+
682
+ MangledAlias->setAliasee (GuestExit);
683
+ return GuestExit;
684
+ }
685
+
618
686
// Lower an indirect call with inline code.
619
687
void AArch64Arm64ECCallLowering::lowerCall (CallBase *CB) {
620
688
assert (Triple (CB->getModule ()->getTargetTriple ()).isOSWindows () &&
@@ -670,17 +738,48 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
670
738
671
739
GuardFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy}, false );
672
740
GuardFnPtrType = PointerType::get (GuardFnType, 0 );
741
+ DispatchFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy, PtrTy}, false );
742
+ DispatchFnPtrType = PointerType::get (DispatchFnType, 0 );
673
743
GuardFnCFGlobal =
674
744
M->getOrInsertGlobal (" __os_arm64x_check_icall_cfg" , GuardFnPtrType);
675
745
GuardFnGlobal =
676
746
M->getOrInsertGlobal (" __os_arm64x_check_icall" , GuardFnPtrType);
747
+ DispatchFnGlobal =
748
+ M->getOrInsertGlobal (" __os_arm64x_dispatch_call" , DispatchFnPtrType);
749
+
750
+ std::map<GlobalAlias *, GlobalAlias *> PatchableFns;
751
+ for (Function &F : Mod) {
752
+ if (!F.hasFnAttribute (Attribute::HybridPatchable) || F.isDeclaration () ||
753
+ F.hasLocalLinkage () || F.getName ().ends_with (" $hp_target" ))
754
+ continue ;
755
+
756
+ // Rename hybrid patchable functions and change callers to use a global
757
+ // alias instead.
758
+ if (std::optional<std::string> MangledName =
759
+ getArm64ECMangledFunctionName (F.getName ().str ())) {
760
+ std::string OrigName (F.getName ());
761
+ F.setName (MangledName.value () + " $hp_target" );
677
762
678
- SetVector<Function *> DirectCalledFns;
763
+ auto *A =
764
+ GlobalAlias::create (GlobalValue::LinkOnceODRLinkage, OrigName, &F);
765
+ F.replaceAllUsesWith (A);
766
+ F.setMetadata (" arm64ec_exp_name" ,
767
+ MDNode::get (M->getContext (),
768
+ MDString::get (M->getContext (),
769
+ " EXP+" + MangledName.value ())));
770
+ A->setAliasee (&F);
771
+
772
+ PatchableFns[A] = GlobalAlias::create (GlobalValue::LinkOnceODRLinkage,
773
+ MangledName.value (), &F);
774
+ }
775
+ }
776
+
777
+ SetVector<GlobalValue *> DirectCalledFns;
679
778
for (Function &F : Mod)
680
779
if (!F.isDeclaration () &&
681
780
F.getCallingConv () != CallingConv::ARM64EC_Thunk_Native &&
682
781
F.getCallingConv () != CallingConv::ARM64EC_Thunk_X64)
683
- processFunction (F, DirectCalledFns);
782
+ processFunction (F, DirectCalledFns, PatchableFns );
684
783
685
784
struct ThunkInfo {
686
785
Constant *Src;
@@ -698,14 +797,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
698
797
{&F, buildEntryThunk (&F), Arm64ECThunkType::Entry});
699
798
}
700
799
}
701
- for (Function *F : DirectCalledFns) {
800
+ for (GlobalValue *O : DirectCalledFns) {
801
+ auto GA = dyn_cast<GlobalAlias>(O);
802
+ auto F = dyn_cast<Function>(GA ? GA->getAliasee () : O);
702
803
ThunkMapping.push_back (
703
- {F , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
804
+ {O , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
704
805
Arm64ECThunkType::Exit});
705
- if (!F->hasDLLImportStorageClass ())
806
+ if (!GA && ! F->hasDLLImportStorageClass ())
706
807
ThunkMapping.push_back (
707
808
{buildGuestExitThunk (F), F, Arm64ECThunkType::GuestExit});
708
809
}
810
+ for (auto I : PatchableFns) {
811
+ Function *Thunk = buildPatchableThunk (I.first , I.second );
812
+ ThunkMapping.push_back ({Thunk, I.first , Arm64ECThunkType::GuestExit});
813
+ }
709
814
710
815
if (!ThunkMapping.empty ()) {
711
816
SmallVector<Constant *> ThunkMappingArrayElems;
@@ -728,7 +833,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
728
833
}
729
834
730
835
bool AArch64Arm64ECCallLowering::processFunction (
731
- Function &F, SetVector<Function *> &DirectCalledFns) {
836
+ Function &F, SetVector<GlobalValue *> &DirectCalledFns,
837
+ std::map<GlobalAlias *, GlobalAlias *> &PatchableFns) {
732
838
SmallVector<CallBase *, 8 > IndirectCalls;
733
839
734
840
// For ARM64EC targets, a function definition's name is mangled differently
@@ -780,6 +886,17 @@ bool AArch64Arm64ECCallLowering::processFunction(
780
886
continue ;
781
887
}
782
888
889
+ // Use mangled global alias for direct calls to patchable functions.
890
+ if (GlobalAlias *A =
891
+ dyn_cast_or_null<GlobalAlias>(CB->getCalledOperand ())) {
892
+ auto I = PatchableFns.find (A);
893
+ if (I != PatchableFns.end ()) {
894
+ CB->setCalledOperand (I->second );
895
+ DirectCalledFns.insert (I->first );
896
+ continue ;
897
+ }
898
+ }
899
+
783
900
IndirectCalls.push_back (CB);
784
901
++Arm64ECCallsLowered;
785
902
}
0 commit comments