Commit 5ee0a71

[aarch64][win] Add support for import call optimization (equivalent to MSVC /d2ImportCallOptimization) (#121516)
This change implements import call optimization for AArch64 Windows (equivalent to the undocumented MSVC `/d2ImportCallOptimization` flag).

Import call optimization adds additional data to the binary which can be used by the Windows kernel loader to rewrite indirect calls to imported functions as direct calls. It uses the same [Dynamic Value Relocation Table mechanism that was leveraged on x64 to implement `/d2GuardRetpoline`](https://techcommunity.microsoft.com/blog/windowsosplatform/mitigating-spectre-variant-2-with-retpoline-on-windows/295618).

The change to the obj file is to add a new `.impcall` section with the following layout:

```cpp
// Per section that contains calls to imported functions:
//  uint32_t SectionSize: Size in bytes for information in this section.
//  uint32_t Section Number
//  Per call to imported function in section:
//    uint32_t Kind: the kind of imported function.
//    uint32_t BranchOffset: the offset of the branch instruction in its
//                            parent section.
//    uint32_t TargetSymbolId: the symbol id of the called function.
```

NOTE: If the import call optimization feature is enabled, then the `.impcall` section must be emitted, even if there are no calls to imported functions.

The implementation is split across a few parts of LLVM:
* During AArch64 instruction selection, the `GlobalValue` for each call to a global is recorded into the Extra Information for that node.
* During lowering to machine instructions, the called global value for each call is noted in its containing `MachineFunction`.
* During AArch64 asm printing, if the import call optimization feature is enabled:
  - A (new) `.impcall` directive is emitted for each call to an imported function.
  - The `.impcall` section is emitted with its magic header (but is not filled in).
* During COFF object writing, the `.impcall` section is filled in based on each `.impcall` directive that was encountered.

The `.impcall` section can only be filled in when we are writing the COFF object as it requires the actual section numbers, which are only assigned at that point (i.e., they don't exist during asm printing).

I had tried to avoid using the Extra Information during instruction selection and instead implement this either purely during asm printing or in a `MachineFunctionPass` (as suggested [on the forums](https://discourse.llvm.org/t/design-gathering-locations-of-instructions-to-emit-into-a-section/83729/3)), but this was not possible due to how loading and calling an imported function works on AArch64. Specifically, such calls are emitted as `ADRP` + `LDR` (to load the symbol) then a `BR` (to do the call), so at the point when we have machine instructions, we would have to work backwards through the instructions to discover what is being called. An initial prototype did work by inspecting instructions; however, it didn't correctly handle the case where the same function was called twice in a row, which caused LLVM to elide the `ADRP` + `LDR` and reuse the previously loaded address. Worse than that, sometimes for the double-call case LLVM decided to spill the loaded address to the stack and then reload it before making the second call. So, instead of trying to implement logic to discover where the value in a register came from, I recorded the symbol being called at the last place where it was easy to do: instruction selection.
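
To make the per-section record layout described in the commit message concrete, here is a minimal standalone sketch that serializes one such record into bytes. It is not the code from this commit: `ImportCall`, `appendU32`, and `encodeSectionRecord` are hypothetical names, the magic header is not modeled because its contents are not given here, and the sketch assumes that `SectionSize` counts the two header words plus all entries.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One call to an imported function, mirroring the per-call fields listed in
// the commit message. The struct name is hypothetical.
struct ImportCall {
  uint32_t Kind;           // kind of imported function (values not given in the message)
  uint32_t BranchOffset;   // offset of the branch instruction in its parent section
  uint32_t TargetSymbolId; // symbol id of the called function
};

// Hypothetical helper: appends V to Out as four little-endian bytes.
void appendU32(std::vector<uint8_t> &Out, uint32_t V) {
  for (int Shift = 0; Shift < 32; Shift += 8)
    Out.push_back(static_cast<uint8_t>((V >> Shift) & 0xFF));
}

// Hypothetical helper: serializes the record for one code section.
// ASSUMPTION: SectionSize covers the two header words and every entry.
std::vector<uint8_t> encodeSectionRecord(uint32_t SectionNumber,
                                         const std::vector<ImportCall> &Calls) {
  std::vector<uint8_t> Out;
  const uint32_t SectionSize = static_cast<uint32_t>(
      2 * sizeof(uint32_t) + Calls.size() * 3 * sizeof(uint32_t));
  appendU32(Out, SectionSize);
  appendU32(Out, SectionNumber);
  for (const ImportCall &C : Calls) {
    appendU32(Out, C.Kind);
    appendU32(Out, C.BranchOffset);
    appendU32(Out, C.TargetSymbolId);
  }
  return Out;
}
```

Under these assumptions, a section with two recorded calls produces a 32-byte record: 8 bytes of header followed by two 12-byte entries. The section number is the piece that is only known during COFF object writing, which is why the commit defers filling in the section until then.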
1 parent 4f6fabd commit 5ee0a71

25 files changed, 673 additions and 38 deletions

llvm/include/llvm/CodeGen/MIRYamlMapping.h

Lines changed: 35 additions & 10 deletions
@@ -457,6 +457,16 @@ template <> struct ScalarTraits<FrameIndex> {
   static QuotingType mustQuote(StringRef S) { return needsQuotes(S); }
 };

+/// Identifies call instruction location in machine function.
+struct MachineInstrLoc {
+  unsigned BlockNum;
+  unsigned Offset;
+
+  bool operator==(const MachineInstrLoc &Other) const {
+    return BlockNum == Other.BlockNum && Offset == Other.Offset;
+  }
+};
+
 /// Serializable representation of CallSiteInfo.
 struct CallSiteInfo {
   // Representation of call argument and register which is used to
@@ -470,16 +480,6 @@ struct CallSiteInfo {
     }
   };

-  /// Identifies call instruction location in machine function.
-  struct MachineInstrLoc {
-    unsigned BlockNum;
-    unsigned Offset;
-
-    bool operator==(const MachineInstrLoc &Other) const {
-      return BlockNum == Other.BlockNum && Offset == Other.Offset;
-    }
-  };
-
   MachineInstrLoc CallLocation;
   std::vector<ArgRegPair> ArgForwardingRegs;

@@ -595,6 +595,26 @@ template <> struct MappingTraits<MachineJumpTable::Entry> {
   }
 };

+struct CalledGlobal {
+  MachineInstrLoc CallSite;
+  StringValue Callee;
+  unsigned Flags;
+
+  bool operator==(const CalledGlobal &Other) const {
+    return CallSite == Other.CallSite && Callee == Other.Callee &&
+           Flags == Other.Flags;
+  }
+};
+
+template <> struct MappingTraits<CalledGlobal> {
+  static void mapping(IO &YamlIO, CalledGlobal &CG) {
+    YamlIO.mapRequired("bb", CG.CallSite.BlockNum);
+    YamlIO.mapRequired("offset", CG.CallSite.Offset);
+    YamlIO.mapRequired("callee", CG.Callee);
+    YamlIO.mapRequired("flags", CG.Flags);
+  }
+};
+
 } // end namespace yaml
 } // end namespace llvm

@@ -606,6 +626,7 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::FixedMachineStackObject)
 LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CallSiteInfo)
 LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::MachineConstantPoolValue)
 LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::MachineJumpTable::Entry)
+LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)

 namespace llvm {
 namespace yaml {
@@ -764,6 +785,7 @@ struct MachineFunction {
   std::vector<DebugValueSubstitution> DebugValueSubstitutions;
   MachineJumpTable JumpTableInfo;
   std::vector<StringValue> MachineMetadataNodes;
+  std::vector<CalledGlobal> CalledGlobals;
   BlockStringValue Body;
 };

@@ -822,6 +844,9 @@ template <> struct MappingTraits<MachineFunction> {
     if (!YamlIO.outputting() || !MF.MachineMetadataNodes.empty())
       YamlIO.mapOptional("machineMetadataNodes", MF.MachineMetadataNodes,
                          std::vector<StringValue>());
+    if (!YamlIO.outputting() || !MF.CalledGlobals.empty())
+      YamlIO.mapOptional("calledGlobals", MF.CalledGlobals,
+                         std::vector<CalledGlobal>());
     YamlIO.mapOptional("body", MF.Body, BlockStringValue());
   }
 };

llvm/include/llvm/CodeGen/MachineFunction.h

Lines changed: 25 additions & 0 deletions
@@ -354,6 +354,11 @@ class LLVM_ABI MachineFunction {
   /// a table of valid targets for Windows EHCont Guard.
   std::vector<MCSymbol *> CatchretTargets;

+  /// Mapping of call instruction to the global value and target flags that it
+  /// calls, if applicable.
+  DenseMap<const MachineInstr *, std::pair<const GlobalValue *, unsigned>>
+      CalledGlobalsMap;
+
   /// \name Exception Handling
   /// \{

@@ -1182,6 +1187,26 @@ class LLVM_ABI MachineFunction {
     CatchretTargets.push_back(Target);
   }

+  /// Tries to get the global and target flags for a call site, if the
+  /// instruction is a call to a global.
+  std::pair<const GlobalValue *, unsigned>
+  tryGetCalledGlobal(const MachineInstr *MI) const {
+    return CalledGlobalsMap.lookup(MI);
+  }
+
+  /// Notes the global and target flags for a call site.
+  void addCalledGlobal(const MachineInstr *MI,
+                       std::pair<const GlobalValue *, unsigned> Details) {
+    assert(MI && "MI must not be null");
+    assert(Details.first && "Global must not be null");
+    CalledGlobalsMap.insert({MI, Details});
+  }
+
+  /// Iterates over the full set of call sites and their associated globals.
+  auto getCalledGlobals() const {
+    return llvm::make_range(CalledGlobalsMap.begin(), CalledGlobalsMap.end());
+  }
+
   /// \name Exception Handling
   /// \{
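
The accessors added to `MachineFunction` rely on two `DenseMap` behaviors: `lookup` returns a default-constructed value when the key is absent, and `insert` leaves an existing entry untouched. The following standalone sketch imitates that with `std::unordered_map`; `CalledGlobalTable` and the empty stand-in types are hypothetical and exist only so the example compiles without LLVM.

```cpp
#include <cassert>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for the LLVM types named in the diff.
struct GlobalValue {};
struct MachineInstr {};

using CalledGlobalInfo = std::pair<const GlobalValue *, unsigned>;

class CalledGlobalTable {
  std::unordered_map<const MachineInstr *, CalledGlobalInfo> Map;

public:
  // Like DenseMap::lookup: yields {nullptr, 0} when MI has no entry.
  CalledGlobalInfo tryGet(const MachineInstr *MI) const {
    auto It = Map.find(MI);
    return It != Map.end() ? It->second : CalledGlobalInfo{};
  }

  // Like insert on a map: the first entry recorded for MI is kept.
  void add(const MachineInstr *MI, CalledGlobalInfo Details) {
    Map.insert({MI, Details});
  }
};
```

A caller of `tryGetCalledGlobal` would therefore test the `first` member for null rather than expect an error or an empty optional when the instruction has no recorded callee.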

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 14 additions & 0 deletions
@@ -293,6 +293,7 @@ class SelectionDAG {
     MDNode *HeapAllocSite = nullptr;
     MDNode *PCSections = nullptr;
     MDNode *MMRA = nullptr;
+    std::pair<const GlobalValue *, unsigned> CalledGlobal{};
     bool NoMerge = false;
   };
   /// Out-of-line extra information for SDNodes.
@@ -2373,6 +2374,19 @@ class SelectionDAG {
     auto It = SDEI.find(Node);
     return It != SDEI.end() ? It->second.MMRA : nullptr;
   }
+  /// Set CalledGlobal to be associated with Node.
+  void addCalledGlobal(const SDNode *Node, const GlobalValue *GV,
+                       unsigned OpFlags) {
+    SDEI[Node].CalledGlobal = {GV, OpFlags};
+  }
+  /// Return CalledGlobal associated with Node, or a nullopt if none exists.
+  std::optional<std::pair<const GlobalValue *, unsigned>>
+  getCalledGlobal(const SDNode *Node) {
+    auto I = SDEI.find(Node);
+    return I != SDEI.end()
+               ? std::make_optional(std::move(I->second).CalledGlobal)
+               : std::nullopt;
+  }
   /// Set NoMergeSiteInfo to be associated with Node if NoMerge is true.
   void addNoMergeSiteInfo(const SDNode *Node, bool NoMerge) {
     if (NoMerge)

llvm/include/llvm/MC/MCObjectFileInfo.h

Lines changed: 5 additions & 0 deletions
@@ -73,6 +73,10 @@ class MCObjectFileInfo {
   /// to emit them into.
   MCSection *CompactUnwindSection = nullptr;

+  /// If import call optimization is supported by the target, this is the
+  /// section to emit import call data to.
+  MCSection *ImportCallSection = nullptr;
+
   // Dwarf sections for debug info.  If a target supports debug info, these must
   // be set.
   MCSection *DwarfAbbrevSection = nullptr;
@@ -269,6 +273,7 @@ class MCObjectFileInfo {
   MCSection *getBSSSection() const { return BSSSection; }
   MCSection *getReadOnlySection() const { return ReadOnlySection; }
   MCSection *getLSDASection() const { return LSDASection; }
+  MCSection *getImportCallSection() const { return ImportCallSection; }
   MCSection *getCompactUnwindSection() const { return CompactUnwindSection; }
   MCSection *getDwarfAbbrevSection() const { return DwarfAbbrevSection; }
   MCSection *getDwarfInfoSection() const { return DwarfInfoSection; }

llvm/include/llvm/MC/MCStreamer.h

Lines changed: 8 additions & 0 deletions
@@ -569,6 +569,14 @@ class MCStreamer {
   /// \param Symbol - Symbol the image relative relocation should point to.
   virtual void emitCOFFImgRel32(MCSymbol const *Symbol, int64_t Offset);

+  /// Emits the physical number of the section containing the given symbol as
+  /// assigned during object writing (i.e., this is not a runtime relocation).
+  virtual void emitCOFFSecNumber(MCSymbol const *Symbol);
+
+  /// Emits the offset of the symbol from the beginning of the section during
+  /// object writing (i.e., this is not a runtime relocation).
+  virtual void emitCOFFSecOffset(MCSymbol const *Symbol);
+
   /// Emits an lcomm directive with XCOFF csect information.
   ///
   /// \param LabelSym - Label on the block of storage.

llvm/include/llvm/MC/MCWinCOFFObjectWriter.h

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ class WinCOFFObjectWriter final : public MCObjectWriter {
                         const MCFixup &Fixup, MCValue Target,
                         uint64_t &FixedValue) override;
   uint64_t writeObject(MCAssembler &Asm) override;
+  int getSectionNumber(const MCSection &Section) const;
 };

 /// Construct a new Win COFF writer instance.

llvm/include/llvm/MC/MCWinCOFFStreamer.h

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ class MCWinCOFFStreamer : public MCObjectStreamer {
   void emitCOFFSectionIndex(MCSymbol const *Symbol) override;
   void emitCOFFSecRel32(MCSymbol const *Symbol, uint64_t Offset) override;
   void emitCOFFImgRel32(MCSymbol const *Symbol, int64_t Offset) override;
+  void emitCOFFSecNumber(MCSymbol const *Symbol) override;
+  void emitCOFFSecOffset(MCSymbol const *Symbol) override;
   void emitCommonSymbol(MCSymbol *Symbol, uint64_t Size,
                         Align ByteAlignment) override;
   void emitLocalCommonSymbol(MCSymbol *Symbol, uint64_t Size,

llvm/lib/CodeGen/MIRParser/MIRParser.cpp

Lines changed: 62 additions & 12 deletions
@@ -158,6 +158,9 @@ class MIRParserImpl {
                                  MachineFunction &MF,
                                  const yaml::MachineFunction &YMF);

+  bool parseCalledGlobals(PerFunctionMIParsingState &PFS, MachineFunction &MF,
+                          const yaml::MachineFunction &YMF);
+
 private:
   bool parseMDNode(PerFunctionMIParsingState &PFS, MDNode *&Node,
                    const yaml::StringValue &Source);
@@ -183,6 +186,9 @@ class MIRParserImpl {

   void setupDebugValueTracking(MachineFunction &MF,
     PerFunctionMIParsingState &PFS, const yaml::MachineFunction &YamlMF);
+
+  bool parseMachineInst(MachineFunction &MF, yaml::MachineInstrLoc MILoc,
+                        MachineInstr const *&MI);
 };

 } // end namespace llvm
@@ -457,24 +463,34 @@ bool MIRParserImpl::computeFunctionProperties(
   return false;
 }

+bool MIRParserImpl::parseMachineInst(MachineFunction &MF,
+                                     yaml::MachineInstrLoc MILoc,
+                                     MachineInstr const *&MI) {
+  if (MILoc.BlockNum >= MF.size()) {
+    return error(Twine(MF.getName()) +
+                 Twine(" instruction block out of range.") +
+                 " Unable to reference bb:" + Twine(MILoc.BlockNum));
+  }
+  auto BB = std::next(MF.begin(), MILoc.BlockNum);
+  if (MILoc.Offset >= BB->size())
+    return error(
+        Twine(MF.getName()) + Twine(" instruction offset out of range.") +
+        " Unable to reference instruction at bb: " + Twine(MILoc.BlockNum) +
+        " at offset:" + Twine(MILoc.Offset));
+  MI = &*std::next(BB->instr_begin(), MILoc.Offset);
+  return false;
+}
+
 bool MIRParserImpl::initializeCallSiteInfo(
     PerFunctionMIParsingState &PFS, const yaml::MachineFunction &YamlMF) {
   MachineFunction &MF = PFS.MF;
   SMDiagnostic Error;
   const TargetMachine &TM = MF.getTarget();
   for (auto &YamlCSInfo : YamlMF.CallSitesInfo) {
-    yaml::CallSiteInfo::MachineInstrLoc MILoc = YamlCSInfo.CallLocation;
-    if (MILoc.BlockNum >= MF.size())
-      return error(Twine(MF.getName()) +
-                   Twine(" call instruction block out of range.") +
-                   " Unable to reference bb:" + Twine(MILoc.BlockNum));
-    auto CallB = std::next(MF.begin(), MILoc.BlockNum);
-    if (MILoc.Offset >= CallB->size())
-      return error(Twine(MF.getName()) +
-                   Twine(" call instruction offset out of range.") +
-                   " Unable to reference instruction at bb: " +
-                   Twine(MILoc.BlockNum) + " at offset:" + Twine(MILoc.Offset));
-    auto CallI = std::next(CallB->instr_begin(), MILoc.Offset);
+    yaml::MachineInstrLoc MILoc = YamlCSInfo.CallLocation;
+    const MachineInstr *CallI;
+    if (parseMachineInst(MF, MILoc, CallI))
+      return true;
     if (!CallI->isCall(MachineInstr::IgnoreBundle))
       return error(Twine(MF.getName()) +
                    Twine(" call site info should reference call "
@@ -641,6 +657,9 @@ MIRParserImpl::initializeMachineFunction(const yaml::MachineFunction &YamlMF,
   if (initializeCallSiteInfo(PFS, YamlMF))
     return true;

+  if (parseCalledGlobals(PFS, MF, YamlMF))
+    return true;
+
   setupDebugValueTracking(MF, PFS, YamlMF);

   MF.getSubtarget().mirFileLoaded(MF);
@@ -1111,6 +1130,37 @@ bool MIRParserImpl::parseMachineMetadataNodes(
   return false;
 }

+bool MIRParserImpl::parseCalledGlobals(PerFunctionMIParsingState &PFS,
+                                       MachineFunction &MF,
+                                       const yaml::MachineFunction &YMF) {
+  Function &F = MF.getFunction();
+  for (const auto &YamlCG : YMF.CalledGlobals) {
+    yaml::MachineInstrLoc MILoc = YamlCG.CallSite;
+    const MachineInstr *CallI;
+    if (parseMachineInst(MF, MILoc, CallI))
+      return true;
+    if (!CallI->isCall(MachineInstr::IgnoreBundle))
+      return error(Twine(MF.getName()) +
+                   Twine(" called global should reference call "
+                         "instruction. Instruction at bb:") +
+                   Twine(MILoc.BlockNum) + " at offset:" + Twine(MILoc.Offset) +
+                   " is not a call instruction");
+
+    auto Callee =
+        F.getParent()->getValueSymbolTable().lookup(YamlCG.Callee.Value);
+    if (!Callee)
+      return error(YamlCG.Callee.SourceRange.Start,
+                   "use of undefined global '" + YamlCG.Callee.Value + "'");
+    if (!isa<GlobalValue>(Callee))
+      return error(YamlCG.Callee.SourceRange.Start,
+                   "use of non-global value '" + YamlCG.Callee.Value + "'");
+
+    MF.addCalledGlobal(CallI, {cast<GlobalValue>(Callee), YamlCG.Flags});
+  }
+
+  return false;
+}
+
 SMDiagnostic MIRParserImpl::diagFromMIStringDiag(const SMDiagnostic &Error,
                                                  SMRange SourceRange) {
   assert(SourceRange.isValid() && "Invalid source range");
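
The `parseMachineInst` helper factored out above turns a (block number, instruction offset) pair into an instruction pointer, following the LLVM parser convention of returning `true` on error. A minimal sketch of the same idea over plain standard containers follows. `Instr`, `Block`, `Function`, `InstrLoc`, and `resolveInstr` are hypothetical stand-ins, and the error strings are shortened for illustration.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>

// Hypothetical stand-ins for MachineInstr, MachineBasicBlock and
// MachineFunction; linked lists are used so that std::next is needed to
// reach an element, as in the diff.
struct Instr {
  std::string Text;
};
using Block = std::list<Instr>;
using Function = std::list<Block>;

struct InstrLoc {
  unsigned BlockNum;
  unsigned Offset;
};

// Resolves Loc to an instruction in F. Returns true on error (and sets Err),
// false on success (and sets Out).
bool resolveInstr(const Function &F, InstrLoc Loc, const Instr *&Out,
                  std::string &Err) {
  if (Loc.BlockNum >= F.size()) {
    Err = "instruction block out of range";
    return true;
  }
  auto BB = std::next(F.begin(), Loc.BlockNum);
  if (Loc.Offset >= BB->size()) {
    Err = "instruction offset out of range";
    return true;
  }
  Out = &*std::next(BB->begin(), Loc.Offset);
  return false;
}
```

Both range checks come before the corresponding `std::next` call, so an out-of-range location is reported instead of advancing an iterator past the end. Sharing one helper lets the call-site-info and called-globals parsers report the same diagnostics.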

llvm/lib/CodeGen/MIRPrinter.cpp

Lines changed: 32 additions & 1 deletion
@@ -133,6 +133,9 @@ class MIRPrinter {
   void convertMachineMetadataNodes(yaml::MachineFunction &YMF,
                                    const MachineFunction &MF,
                                    MachineModuleSlotTracker &MST);
+  void convertCalledGlobals(yaml::MachineFunction &YMF,
+                            const MachineFunction &MF,
+                            MachineModuleSlotTracker &MST);

 private:
   void initRegisterMaskIds(const MachineFunction &MF);
@@ -269,6 +272,8 @@ void MIRPrinter::print(const MachineFunction &MF) {
   // function.
   convertMachineMetadataNodes(YamlMF, MF, MST);

+  convertCalledGlobals(YamlMF, MF, MST);
+
   yaml::Output Out(OS);
   if (!SimplifyMIR)
     Out.setWriteDefaultValues(true);
@@ -555,7 +560,7 @@ void MIRPrinter::convertCallSiteObjects(yaml::MachineFunction &YMF,
   const auto *TRI = MF.getSubtarget().getRegisterInfo();
   for (auto CSInfo : MF.getCallSitesInfo()) {
     yaml::CallSiteInfo YmlCS;
-    yaml::CallSiteInfo::MachineInstrLoc CallLocation;
+    yaml::MachineInstrLoc CallLocation;

     // Prepare instruction position.
     MachineBasicBlock::const_instr_iterator CallI = CSInfo.first->getIterator();
@@ -596,6 +601,32 @@ void MIRPrinter::convertMachineMetadataNodes(yaml::MachineFunction &YMF,
   }
 }

+void MIRPrinter::convertCalledGlobals(yaml::MachineFunction &YMF,
+                                      const MachineFunction &MF,
+                                      MachineModuleSlotTracker &MST) {
+  for (const auto [CallInst, CG] : MF.getCalledGlobals()) {
+    // If the call instruction was dropped, then we don't need to print it.
+    auto BB = CallInst->getParent();
+    if (BB) {
+      yaml::MachineInstrLoc CallSite;
+      CallSite.BlockNum = CallInst->getParent()->getNumber();
+      CallSite.Offset = std::distance(CallInst->getParent()->instr_begin(),
+                                      CallInst->getIterator());
+
+      yaml::CalledGlobal YamlCG{CallSite, CG.first->getName().str(), CG.second};
+      YMF.CalledGlobals.push_back(YamlCG);
+    }
+  }
+
+  // Sort by position of call instructions.
+  llvm::sort(YMF.CalledGlobals.begin(), YMF.CalledGlobals.end(),
+             [](yaml::CalledGlobal A, yaml::CalledGlobal B) {
+               if (A.CallSite.BlockNum == B.CallSite.BlockNum)
+                 return A.CallSite.Offset < B.CallSite.Offset;
+               return A.CallSite.BlockNum < B.CallSite.BlockNum;
+             });
+}
+
 void MIRPrinter::convert(yaml::MachineFunction &MF,
                          const MachineConstantPool &ConstantPool) {
   unsigned ID = 0;
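
`convertCalledGlobals` sorts its output by call position, presumably because iteration order over a `DenseMap` keyed by pointers is not stable and printed MIR should be deterministic. The comparator orders entries by block number first and offset second. A standalone sketch of that ordering, written with `std::tie` instead of the explicit `if`, is shown below; `InstrLoc`, `CalledGlobalEntry`, and `sortByCallSite` are hypothetical names, and the lambda takes its arguments by const reference, which differs from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-ins for yaml::MachineInstrLoc and yaml::CalledGlobal.
struct InstrLoc {
  unsigned BlockNum;
  unsigned Offset;
};
struct CalledGlobalEntry {
  InstrLoc CallSite;
  std::string Callee;
  unsigned Flags;
};

// Orders entries by (BlockNum, Offset), the same lexicographic order as the
// comparator in the diff.
void sortByCallSite(std::vector<CalledGlobalEntry> &Entries) {
  std::sort(Entries.begin(), Entries.end(),
            [](const CalledGlobalEntry &A, const CalledGlobalEntry &B) {
              return std::tie(A.CallSite.BlockNum, A.CallSite.Offset) <
                     std::tie(B.CallSite.BlockNum, B.CallSite.Offset);
            });
}
```

Comparing tuples of references gives a strict weak ordering without hand-written branching, and passing by reference avoids copying the callee string on every comparison.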

llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp

Lines changed: 4 additions & 0 deletions
@@ -908,6 +908,10 @@ EmitSchedule(MachineBasicBlock::iterator &InsertPos) {
       It->setMMRAMetadata(MF, MMRA);
     }

+    if (auto CalledGlobal = DAG->getCalledGlobal(Node))
+      if (CalledGlobal->first)
+        MF.addCalledGlobal(MI, *CalledGlobal);
+
     return MI;
   };
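
The nested check in `EmitSchedule` is needed because `getCalledGlobal` returns an engaged optional for any node that has extra information, even when that information only carries, for example, a `NoMerge` flag. In that case the pair inside is `{nullptr, 0}`. A minimal sketch of this behavior, using hypothetical stand-in types and `std::map` in place of the `SDEI` container, assuming C++17:

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <utility>

// Hypothetical stand-ins for the LLVM types named in the diff.
struct GlobalValue {};
struct SDNode {};

struct NodeExtraInfo {
  std::pair<const GlobalValue *, unsigned> CalledGlobal{};
  bool NoMerge = false;
};

class ExtraInfoTable {
  std::map<const SDNode *, NodeExtraInfo> SDEI;

public:
  void addNoMerge(const SDNode *N) { SDEI[N].NoMerge = true; }
  void addCalledGlobal(const SDNode *N, const GlobalValue *GV, unsigned Flags) {
    SDEI[N].CalledGlobal = {GV, Flags};
  }
  // Engaged whenever N has any extra info, even if no global was recorded.
  std::optional<std::pair<const GlobalValue *, unsigned>>
  getCalledGlobal(const SDNode *N) const {
    auto I = SDEI.find(N);
    if (I == SDEI.end())
      return std::nullopt;
    return I->second.CalledGlobal;
  }
};

// Mirrors the two-level check performed at the emit site.
bool hasCalledGlobal(const ExtraInfoTable &T, const SDNode *N) {
  auto CG = T.getCalledGlobal(N);
  return CG && CG->first;
}
```

Without the inner null test, a node carrying only unrelated extra info would trip the `Global must not be null` assertion in `MachineFunction::addCalledGlobal`.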
