21
21
#include " llvm/ADT/SmallVector.h"
22
22
#include " llvm/ADT/Statistic.h"
23
23
#include " llvm/IR/CallingConv.h"
24
+ #include " llvm/IR/GlobalAlias.h"
24
25
#include " llvm/IR/IRBuilder.h"
25
26
#include " llvm/IR/Instruction.h"
26
27
#include " llvm/IR/Mangler.h"
@@ -70,15 +71,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
70
71
Function *buildEntryThunk (Function *F);
71
72
void lowerCall (CallBase *CB);
72
73
Function *buildGuestExitThunk (Function *F);
73
- bool processFunction (Function &F, SetVector<Function *> &DirectCalledFns);
74
+ Function *buildPatchableThunk (GlobalAlias *UnmangledAlias,
75
+ GlobalAlias *MangledAlias);
76
+ bool processFunction (Function &F, SetVector<GlobalValue *> &DirectCalledFns,
77
+ DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
74
78
bool runOnModule (Module &M) override ;
75
79
76
80
private:
77
81
int cfguard_module_flag = 0 ;
78
82
FunctionType *GuardFnType = nullptr ;
79
83
PointerType *GuardFnPtrType = nullptr ;
84
+ FunctionType *DispatchFnType = nullptr ;
85
+ PointerType *DispatchFnPtrType = nullptr ;
80
86
Constant *GuardFnCFGlobal = nullptr ;
81
87
Constant *GuardFnGlobal = nullptr ;
88
+ Constant *DispatchFnGlobal = nullptr ;
82
89
Module *M = nullptr ;
83
90
84
91
Type *PtrTy;
@@ -672,6 +679,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
672
679
return GuestExit;
673
680
}
674
681
682
+ Function *
683
+ AArch64Arm64ECCallLowering::buildPatchableThunk (GlobalAlias *UnmangledAlias,
684
+ GlobalAlias *MangledAlias) {
685
+ llvm::raw_null_ostream NullThunkName;
686
+ FunctionType *Arm64Ty, *X64Ty;
687
+ Function *F = cast<Function>(MangledAlias->getAliasee ());
688
+ SmallVector<ThunkArgTranslation> ArgTranslations;
689
+ getThunkType (F->getFunctionType (), F->getAttributes (),
690
+ Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
691
+ ArgTranslations);
692
+ std::string ThunkName (MangledAlias->getName ());
693
+ if (ThunkName[0 ] == ' ?' && ThunkName.find (" @" ) != std::string::npos) {
694
+ ThunkName.insert (ThunkName.find (" @" ), " $hybpatch_thunk" );
695
+ } else {
696
+ ThunkName.append (" $hybpatch_thunk" );
697
+ }
698
+
699
+ Function *GuestExit =
700
+ Function::Create (Arm64Ty, GlobalValue::WeakODRLinkage, 0 , ThunkName, M);
701
+ GuestExit->setComdat (M->getOrInsertComdat (ThunkName));
702
+ GuestExit->setSection (" .wowthk$aa" );
703
+ BasicBlock *BB = BasicBlock::Create (M->getContext (), " " , GuestExit);
704
+ IRBuilder<> B (BB);
705
+
706
+ // Load the global symbol as a pointer to the check function.
707
+ LoadInst *DispatchLoad = B.CreateLoad (DispatchFnPtrType, DispatchFnGlobal);
708
+
709
+ // Create new dispatch call instruction.
710
+ Function *ExitThunk =
711
+ buildExitThunk (F->getFunctionType (), F->getAttributes ());
712
+ CallInst *Dispatch =
713
+ B.CreateCall (DispatchFnType, DispatchLoad,
714
+ {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee ()});
715
+
716
+ // Ensure that the first arguments are passed in the correct registers.
717
+ Dispatch->setCallingConv (CallingConv::CFGuard_Check);
718
+
719
+ Value *DispatchRetVal = B.CreateBitCast (Dispatch, PtrTy);
720
+ SmallVector<Value *> Args;
721
+ for (Argument &Arg : GuestExit->args ())
722
+ Args.push_back (&Arg);
723
+ CallInst *Call = B.CreateCall (Arm64Ty, DispatchRetVal, Args);
724
+ Call->setTailCallKind (llvm::CallInst::TCK_MustTail);
725
+
726
+ if (Call->getType ()->isVoidTy ())
727
+ B.CreateRetVoid ();
728
+ else
729
+ B.CreateRet (Call);
730
+
731
+ auto SRetAttr = F->getAttributes ().getParamAttr (0 , Attribute::StructRet);
732
+ auto InRegAttr = F->getAttributes ().getParamAttr (0 , Attribute::InReg);
733
+ if (SRetAttr.isValid () && !InRegAttr.isValid ()) {
734
+ GuestExit->addParamAttr (0 , SRetAttr);
735
+ Call->addParamAttr (0 , SRetAttr);
736
+ }
737
+
738
+ MangledAlias->setAliasee (GuestExit);
739
+ return GuestExit;
740
+ }
741
+
675
742
// Lower an indirect call with inline code.
676
743
void AArch64Arm64ECCallLowering::lowerCall (CallBase *CB) {
677
744
assert (Triple (CB->getModule ()->getTargetTriple ()).isOSWindows () &&
@@ -727,17 +794,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
727
794
728
795
GuardFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy}, false );
729
796
GuardFnPtrType = PointerType::get (GuardFnType, 0 );
797
+ DispatchFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy, PtrTy}, false );
798
+ DispatchFnPtrType = PointerType::get (DispatchFnType, 0 );
730
799
GuardFnCFGlobal =
731
800
M->getOrInsertGlobal (" __os_arm64x_check_icall_cfg" , GuardFnPtrType);
732
801
GuardFnGlobal =
733
802
M->getOrInsertGlobal (" __os_arm64x_check_icall" , GuardFnPtrType);
803
+ DispatchFnGlobal =
804
+ M->getOrInsertGlobal (" __os_arm64x_dispatch_call" , DispatchFnPtrType);
805
+
806
+ DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
807
+ SetVector<GlobalAlias *> PatchableFns;
734
808
735
- SetVector<Function *> DirectCalledFns;
809
+ for (Function &F : Mod) {
810
+ if (!F.hasFnAttribute (Attribute::HybridPatchable) || F.isDeclaration () ||
811
+ F.hasLocalLinkage () || F.getName ().ends_with (" $hp_target" ))
812
+ continue ;
813
+
814
+ // Rename hybrid patchable functions and change callers to use a global
815
+ // alias instead.
816
+ if (std::optional<std::string> MangledName =
817
+ getArm64ECMangledFunctionName (F.getName ().str ())) {
818
+ std::string OrigName (F.getName ());
819
+ F.setName (MangledName.value () + " $hp_target" );
820
+
821
+ // The unmangled symbol is a weak alias to an undefined symbol with the
822
+ // "EXP+" prefix. This undefined symbol is resolved by the linker by
823
+ // creating an x86 thunk that jumps back to the actual EC target. Since we
824
+ // can't represent that in IR, we create an alias to the target instead.
825
+ // The "EXP+" symbol is set as metadata, which is then used by
826
+ // emitGlobalAlias to emit the right alias.
827
+ auto *A =
828
+ GlobalAlias::create (GlobalValue::LinkOnceODRLinkage, OrigName, &F);
829
+ F.replaceAllUsesWith (A);
830
+ F.setMetadata (" arm64ec_exp_name" ,
831
+ MDNode::get (M->getContext (),
832
+ MDString::get (M->getContext (),
833
+ " EXP+" + MangledName.value ())));
834
+ A->setAliasee (&F);
835
+
836
+ FnsMap[A] = GlobalAlias::create (GlobalValue::LinkOnceODRLinkage,
837
+ MangledName.value (), &F);
838
+ PatchableFns.insert (A);
839
+ }
840
+ }
841
+
842
+ SetVector<GlobalValue *> DirectCalledFns;
736
843
for (Function &F : Mod)
737
844
if (!F.isDeclaration () &&
738
845
F.getCallingConv () != CallingConv::ARM64EC_Thunk_Native &&
739
846
F.getCallingConv () != CallingConv::ARM64EC_Thunk_X64)
740
- processFunction (F, DirectCalledFns);
847
+ processFunction (F, DirectCalledFns, FnsMap );
741
848
742
849
struct ThunkInfo {
743
850
Constant *Src;
@@ -755,14 +862,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
755
862
{&F, buildEntryThunk (&F), Arm64ECThunkType::Entry});
756
863
}
757
864
}
758
- for (Function *F : DirectCalledFns) {
865
+ for (GlobalValue *O : DirectCalledFns) {
866
+ auto GA = dyn_cast<GlobalAlias>(O);
867
+ auto F = dyn_cast<Function>(GA ? GA->getAliasee () : O);
759
868
ThunkMapping.push_back (
760
- {F , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
869
+ {O , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
761
870
Arm64ECThunkType::Exit});
762
- if (!F->hasDLLImportStorageClass ())
871
+ if (!GA && ! F->hasDLLImportStorageClass ())
763
872
ThunkMapping.push_back (
764
873
{buildGuestExitThunk (F), F, Arm64ECThunkType::GuestExit});
765
874
}
875
+ for (GlobalAlias *A : PatchableFns) {
876
+ Function *Thunk = buildPatchableThunk (A, FnsMap[A]);
877
+ ThunkMapping.push_back ({Thunk, A, Arm64ECThunkType::GuestExit});
878
+ }
766
879
767
880
if (!ThunkMapping.empty ()) {
768
881
SmallVector<Constant *> ThunkMappingArrayElems;
@@ -785,7 +898,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
785
898
}
786
899
787
900
bool AArch64Arm64ECCallLowering::processFunction (
788
- Function &F, SetVector<Function *> &DirectCalledFns) {
901
+ Function &F, SetVector<GlobalValue *> &DirectCalledFns,
902
+ DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
789
903
SmallVector<CallBase *, 8 > IndirectCalls;
790
904
791
905
// For ARM64EC targets, a function definition's name is mangled differently
@@ -837,6 +951,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
837
951
continue ;
838
952
}
839
953
954
+ // Use mangled global alias for direct calls to patchable functions.
955
+ if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand ())) {
956
+ auto I = FnsMap.find (A);
957
+ if (I != FnsMap.end ()) {
958
+ CB->setCalledOperand (I->second );
959
+ DirectCalledFns.insert (I->first );
960
+ continue ;
961
+ }
962
+ }
963
+
840
964
IndirectCalls.push_back (CB);
841
965
++Arm64ECCallsLowered;
842
966
}
0 commit comments