21
21
#include " llvm/ADT/SmallVector.h"
22
22
#include " llvm/ADT/Statistic.h"
23
23
#include " llvm/IR/CallingConv.h"
24
+ #include " llvm/IR/GlobalAlias.h"
24
25
#include " llvm/IR/IRBuilder.h"
25
26
#include " llvm/IR/Instruction.h"
26
27
#include " llvm/IR/Mangler.h"
@@ -69,15 +70,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
69
70
Function *buildEntryThunk (Function *F);
70
71
void lowerCall (CallBase *CB);
71
72
Function *buildGuestExitThunk (Function *F);
72
- bool processFunction (Function &F, SetVector<Function *> &DirectCalledFns);
73
+ Function *buildPatchableThunk (GlobalAlias *UnmangledAlias,
74
+ GlobalAlias *MangledAlias);
75
+ bool processFunction (Function &F, SetVector<GlobalValue *> &DirectCalledFns,
76
+ DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
73
77
bool runOnModule (Module &M) override ;
74
78
75
79
private:
76
80
int cfguard_module_flag = 0 ;
77
81
FunctionType *GuardFnType = nullptr ;
78
82
PointerType *GuardFnPtrType = nullptr ;
83
+ FunctionType *DispatchFnType = nullptr ;
84
+ PointerType *DispatchFnPtrType = nullptr ;
79
85
Constant *GuardFnCFGlobal = nullptr ;
80
86
Constant *GuardFnGlobal = nullptr ;
87
+ Constant *DispatchFnGlobal = nullptr ;
81
88
Module *M = nullptr ;
82
89
83
90
Type *PtrTy;
@@ -671,6 +678,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
671
678
return GuestExit;
672
679
}
673
680
681
+ Function *
682
+ AArch64Arm64ECCallLowering::buildPatchableThunk (GlobalAlias *UnmangledAlias,
683
+ GlobalAlias *MangledAlias) {
684
+ llvm::raw_null_ostream NullThunkName;
685
+ FunctionType *Arm64Ty, *X64Ty;
686
+ Function *F = cast<Function>(MangledAlias->getAliasee ());
687
+ SmallVector<ThunkArgTranslation> ArgTranslations;
688
+ getThunkType (F->getFunctionType (), F->getAttributes (),
689
+ Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
690
+ ArgTranslations);
691
+ std::string ThunkName (MangledAlias->getName ());
692
+ if (ThunkName[0 ] == ' ?' && ThunkName.find (" @" ) != std::string::npos) {
693
+ ThunkName.insert (ThunkName.find (" @" ), " $hybpatch_thunk" );
694
+ } else {
695
+ ThunkName.append (" $hybpatch_thunk" );
696
+ }
697
+
698
+ Function *GuestExit =
699
+ Function::Create (Arm64Ty, GlobalValue::WeakODRLinkage, 0 , ThunkName, M);
700
+ GuestExit->setComdat (M->getOrInsertComdat (ThunkName));
701
+ GuestExit->setSection (" .wowthk$aa" );
702
+ BasicBlock *BB = BasicBlock::Create (M->getContext (), " " , GuestExit);
703
+ IRBuilder<> B (BB);
704
+
705
+ // Load the global symbol as a pointer to the check function.
706
+ LoadInst *DispatchLoad = B.CreateLoad (DispatchFnPtrType, DispatchFnGlobal);
707
+
708
+ // Create new dispatch call instruction.
709
+ Function *ExitThunk =
710
+ buildExitThunk (F->getFunctionType (), F->getAttributes ());
711
+ CallInst *Dispatch =
712
+ B.CreateCall (DispatchFnType, DispatchLoad,
713
+ {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee ()});
714
+
715
+ // Ensure that the first arguments are passed in the correct registers.
716
+ Dispatch->setCallingConv (CallingConv::CFGuard_Check);
717
+
718
+ Value *DispatchRetVal = B.CreateBitCast (Dispatch, PtrTy);
719
+ SmallVector<Value *> Args;
720
+ for (Argument &Arg : GuestExit->args ())
721
+ Args.push_back (&Arg);
722
+ CallInst *Call = B.CreateCall (Arm64Ty, DispatchRetVal, Args);
723
+ Call->setTailCallKind (llvm::CallInst::TCK_MustTail);
724
+
725
+ if (Call->getType ()->isVoidTy ())
726
+ B.CreateRetVoid ();
727
+ else
728
+ B.CreateRet (Call);
729
+
730
+ auto SRetAttr = F->getAttributes ().getParamAttr (0 , Attribute::StructRet);
731
+ auto InRegAttr = F->getAttributes ().getParamAttr (0 , Attribute::InReg);
732
+ if (SRetAttr.isValid () && !InRegAttr.isValid ()) {
733
+ GuestExit->addParamAttr (0 , SRetAttr);
734
+ Call->addParamAttr (0 , SRetAttr);
735
+ }
736
+
737
+ MangledAlias->setAliasee (GuestExit);
738
+ return GuestExit;
739
+ }
740
+
674
741
// Lower an indirect call with inline code.
675
742
void AArch64Arm64ECCallLowering::lowerCall (CallBase *CB) {
676
743
assert (Triple (CB->getModule ()->getTargetTriple ()).isOSWindows () &&
@@ -726,17 +793,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
726
793
727
794
GuardFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy}, false );
728
795
GuardFnPtrType = PointerType::get (GuardFnType, 0 );
796
+ DispatchFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy, PtrTy}, false );
797
+ DispatchFnPtrType = PointerType::get (DispatchFnType, 0 );
729
798
GuardFnCFGlobal =
730
799
M->getOrInsertGlobal (" __os_arm64x_check_icall_cfg" , GuardFnPtrType);
731
800
GuardFnGlobal =
732
801
M->getOrInsertGlobal (" __os_arm64x_check_icall" , GuardFnPtrType);
802
+ DispatchFnGlobal =
803
+ M->getOrInsertGlobal (" __os_arm64x_dispatch_call" , DispatchFnPtrType);
804
+
805
+ DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
806
+ SetVector<GlobalAlias *> PatchableFns;
733
807
734
- SetVector<Function *> DirectCalledFns;
808
+ for (Function &F : Mod) {
809
+ if (!F.hasFnAttribute (Attribute::HybridPatchable) || F.isDeclaration () ||
810
+ F.hasLocalLinkage () || F.getName ().ends_with (" $hp_target" ))
811
+ continue ;
812
+
813
+ // Rename hybrid patchable functions and change callers to use a global
814
+ // alias instead.
815
+ if (std::optional<std::string> MangledName =
816
+ getArm64ECMangledFunctionName (F.getName ().str ())) {
817
+ std::string OrigName (F.getName ());
818
+ F.setName (MangledName.value () + " $hp_target" );
819
+
820
+ // The unmangled symbol is a weak alias to an undefined symbol with the
821
+ // "EXP+" prefix. This undefined symbol is resolved by the linker by
822
+ // creating an x86 thunk that jumps back to the actual EC target. Since we
823
+ // can't represent that in IR, we create an alias to the target instead.
824
+ // The "EXP+" symbol is set as metadata, which is then used by
825
+ // emitGlobalAlias to emit the right alias.
826
+ auto *A =
827
+ GlobalAlias::create (GlobalValue::LinkOnceODRLinkage, OrigName, &F);
828
+ F.replaceAllUsesWith (A);
829
+ F.setMetadata (" arm64ec_exp_name" ,
830
+ MDNode::get (M->getContext (),
831
+ MDString::get (M->getContext (),
832
+ " EXP+" + MangledName.value ())));
833
+ A->setAliasee (&F);
834
+
835
+ FnsMap[A] = GlobalAlias::create (GlobalValue::LinkOnceODRLinkage,
836
+ MangledName.value (), &F);
837
+ PatchableFns.insert (A);
838
+ }
839
+ }
840
+
841
+ SetVector<GlobalValue *> DirectCalledFns;
735
842
for (Function &F : Mod)
736
843
if (!F.isDeclaration () &&
737
844
F.getCallingConv () != CallingConv::ARM64EC_Thunk_Native &&
738
845
F.getCallingConv () != CallingConv::ARM64EC_Thunk_X64)
739
- processFunction (F, DirectCalledFns);
846
+ processFunction (F, DirectCalledFns, FnsMap );
740
847
741
848
struct ThunkInfo {
742
849
Constant *Src;
@@ -754,14 +861,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
754
861
{&F, buildEntryThunk (&F), Arm64ECThunkType::Entry});
755
862
}
756
863
}
757
- for (Function *F : DirectCalledFns) {
864
+ for (GlobalValue *O : DirectCalledFns) {
865
+ auto GA = dyn_cast<GlobalAlias>(O);
866
+ auto F = dyn_cast<Function>(GA ? GA->getAliasee () : O);
758
867
ThunkMapping.push_back (
759
- {F , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
868
+ {O , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
760
869
Arm64ECThunkType::Exit});
761
- if (!F->hasDLLImportStorageClass ())
870
+ if (!GA && ! F->hasDLLImportStorageClass ())
762
871
ThunkMapping.push_back (
763
872
{buildGuestExitThunk (F), F, Arm64ECThunkType::GuestExit});
764
873
}
874
+ for (GlobalAlias *A : PatchableFns) {
875
+ Function *Thunk = buildPatchableThunk (A, FnsMap[A]);
876
+ ThunkMapping.push_back ({Thunk, A, Arm64ECThunkType::GuestExit});
877
+ }
765
878
766
879
if (!ThunkMapping.empty ()) {
767
880
SmallVector<Constant *> ThunkMappingArrayElems;
@@ -784,7 +897,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
784
897
}
785
898
786
899
bool AArch64Arm64ECCallLowering::processFunction (
787
- Function &F, SetVector<Function *> &DirectCalledFns) {
900
+ Function &F, SetVector<GlobalValue *> &DirectCalledFns,
901
+ DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
788
902
SmallVector<CallBase *, 8 > IndirectCalls;
789
903
790
904
// For ARM64EC targets, a function definition's name is mangled differently
@@ -836,6 +950,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
836
950
continue ;
837
951
}
838
952
953
+ // Use mangled global alias for direct calls to patchable functions.
954
+ if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand ())) {
955
+ auto I = FnsMap.find (A);
956
+ if (I != FnsMap.end ()) {
957
+ CB->setCalledOperand (I->second );
958
+ DirectCalledFns.insert (I->first );
959
+ continue ;
960
+ }
961
+ }
962
+
839
963
IndirectCalls.push_back (CB);
840
964
++Arm64ECCallsLowered;
841
965
}
0 commit comments