Skip to content

Commit 9207e12

Browse files
[PAC][AArch64] Lower ptrauth constants in code
Define the following pseudos for lowering ptrauth constants in code: - non-`extern_weak`: - no GOT load needed: `MOVaddrPAC` - similar to `MOVaddr`, with added PAC; - GOT load needed: `LOADgotPAC` - similar to `LOADgot`, with added PAC; - `extern_weak`: `LOADauthptrstatic` - similar to `LOADgot`, but use a special stub slot named `sym$auth_ptr$key$disc` filled by dynamic linker during relocation resolving instead of a GOT slot. Co-authored-by: Ahmed Bougacha <[email protected]>
1 parent f779ec7 commit 9207e12

22 files changed

+854
-37
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ enum NodeType {
8383
ExternalSymbol,
8484
BlockAddress,
8585

86+
/// A ptrauth constant.
87+
/// ptr, key, addr-disc, disc
88+
/// Note that the addr-disc can be a non-constant value, to allow representing
89+
/// a constant global address signed using address-diversification, in code.
90+
PtrAuthGlobalAddress,
91+
8692
/// The address of the GOT
8793
GLOBAL_OFFSET_TABLE,
8894

llvm/include/llvm/CodeGen/MachineModuleInfoImpls.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,18 @@ class MachineModuleInfoMachO : public MachineModuleInfoImpl {
6161
/// MachineModuleInfoELF - This is a MachineModuleInfoImpl implementation
6262
/// for ELF targets.
6363
class MachineModuleInfoELF : public MachineModuleInfoImpl {
64+
public:
65+
struct AuthStubInfo {
66+
const MCExpr *AuthPtrRef;
67+
};
68+
69+
private:
6470
/// GVStubs - These stubs are used to materialize global addresses in PIC
6571
/// mode.
6672
DenseMap<MCSymbol *, StubValueTy> GVStubs;
6773

74+
DenseMap<MCSymbol *, AuthStubInfo> AuthPtrStubs;
75+
6876
virtual void anchor(); // Out of line virtual method.
6977

7078
public:
@@ -75,6 +83,11 @@ class MachineModuleInfoELF : public MachineModuleInfoImpl {
7583
return GVStubs[Sym];
7684
}
7785

86+
AuthStubInfo &getAuthPtrStubEntry(MCSymbol *Sym) {
87+
assert(Sym && "Key cannot be null");
88+
return AuthPtrStubs[Sym];
89+
}
90+
7891
/// Accessor methods to return the set of stubs in sorted order.
7992

8093
SymbolListTy GetGVStubList() { return getSortedStubs(GVStubs); }

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@ HANDLE_TARGET_OPCODE(G_FRAME_INDEX)
294294
/// Generic reference to global value.
295295
HANDLE_TARGET_OPCODE(G_GLOBAL_VALUE)
296296

297+
/// Generic ptrauth-signed reference to global value.
298+
HANDLE_TARGET_OPCODE(G_PTRAUTH_GLOBAL_VALUE)
299+
297300
/// Generic instruction to materialize the address of an object in the constant
298301
/// pool.
299302
HANDLE_TARGET_OPCODE(G_CONSTANT_POOL)

llvm/include/llvm/Target/GenericOpcodes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ def G_GLOBAL_VALUE : GenericInstruction {
110110
let hasSideEffects = false;
111111
}
112112

113+
def G_PTRAUTH_GLOBAL_VALUE : GenericInstruction {
114+
let OutOperandList = (outs type0:$dst);
115+
let InOperandList = (ins unknown:$addr, i32imm:$key, type1:$addrdisc, i64imm:$disc);
116+
let hasSideEffects = 0;
117+
}
118+
113119
def G_CONSTANT_POOL : GenericInstruction {
114120
let OutOperandList = (outs type0:$dst);
115121
let InOperandList = (ins unknown:$src);

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3490,7 +3490,16 @@ bool IRTranslator::translate(const Constant &C, Register Reg) {
34903490
EntryBuilder->buildConstant(Reg, 0);
34913491
else if (auto GV = dyn_cast<GlobalValue>(&C))
34923492
EntryBuilder->buildGlobalValue(Reg, GV);
3493-
else if (auto CAZ = dyn_cast<ConstantAggregateZero>(&C)) {
3493+
else if (auto CPA = dyn_cast<ConstantPtrAuth>(&C)) {
3494+
Register Addr = getOrCreateVReg(*CPA->getPointer());
3495+
Register AddrDisc = getOrCreateVReg(*CPA->getAddrDiscriminator());
3496+
EntryBuilder->buildInstr(TargetOpcode::G_PTRAUTH_GLOBAL_VALUE)
3497+
.addDef(Reg)
3498+
.addUse(Addr)
3499+
.addImm(CPA->getKey()->getZExtValue())
3500+
.addUse(AddrDisc)
3501+
.addImm(CPA->getDiscriminator()->getZExtValue());
3502+
} else if (auto CAZ = dyn_cast<ConstantAggregateZero>(&C)) {
34943503
if (!isa<FixedVectorType>(CAZ->getType()))
34953504
return false;
34963505
// Return the scalar if it is a <1 x Ty> vector.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,6 +1794,13 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
17941794
if (const GlobalValue *GV = dyn_cast<GlobalValue>(C))
17951795
return DAG.getGlobalAddress(GV, getCurSDLoc(), VT);
17961796

1797+
if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(C)) {
1798+
return DAG.getNode(ISD::PtrAuthGlobalAddress, getCurSDLoc(), VT,
1799+
getValue(CPA->getPointer()), getValue(CPA->getKey()),
1800+
getValue(CPA->getAddrDiscriminator()),
1801+
getValue(CPA->getDiscriminator()));
1802+
}
1803+
17971804
if (isa<ConstantPointerNull>(C)) {
17981805
unsigned AS = V->getType()->getPointerAddressSpace();
17991806
return DAG.getConstant(0, getCurSDLoc(),

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
7575
}
7676
return "<<Unknown Node #" + utostr(getOpcode()) + ">>";
7777

78+
// clang-format off
7879
#ifndef NDEBUG
7980
case ISD::DELETED_NODE: return "<<Deleted Node!>>";
8081
#endif
@@ -124,6 +125,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
124125
case ISD::ConstantFP: return "ConstantFP";
125126
case ISD::GlobalAddress: return "GlobalAddress";
126127
case ISD::GlobalTLSAddress: return "GlobalTLSAddress";
128+
case ISD::PtrAuthGlobalAddress: return "PtrAuthGlobalAddress";
127129
case ISD::FrameIndex: return "FrameIndex";
128130
case ISD::JumpTable: return "JumpTable";
129131
case ISD::JUMP_TABLE_DEBUG_INFO:
@@ -166,8 +168,6 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
166168
return "OpaqueTargetConstant";
167169
return "TargetConstant";
168170

169-
// clang-format off
170-
171171
case ISD::TargetConstantFP: return "TargetConstantFP";
172172
case ISD::TargetGlobalAddress: return "TargetGlobalAddress";
173173
case ISD::TargetGlobalTLSAddress: return "TargetGlobalTLSAddress";

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ class AArch64AsmPrinter : public AsmPrinter {
131131
unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
132132
unsigned &InstsEmitted);
133133

134+
// Emit the sequence for LOADauthptrstatic
135+
void LowerLOADauthptrstatic(const MachineInstr &MI);
136+
137+
// Emit the sequence for LOADgotPAC/MOVaddrPAC (either GOT adrp-ldr or
138+
// adrp-add followed by PAC sign)
139+
void LowerMOVaddrPAC(const MachineInstr &MI);
140+
134141
/// tblgen'erated driver function for lowering simple MI->MC
135142
/// pseudo instructions.
136143
bool emitPseudoExpansionLowering(MCStreamer &OutStreamer,
@@ -1575,6 +1582,173 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
15751582
assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
15761583
}
15771584

1585+
void AArch64AsmPrinter::LowerLOADauthptrstatic(const MachineInstr &MI) {
1586+
unsigned DstReg = MI.getOperand(0).getReg();
1587+
MachineOperand GAOp = MI.getOperand(1);
1588+
uint64_t KeyC = MI.getOperand(2).getImm();
1589+
assert(KeyC <= AArch64PACKey::LAST && "Key is out of range");
1590+
auto Key = (AArch64PACKey::ID)KeyC;
1591+
uint64_t Disc = MI.getOperand(3).getImm();
1592+
assert(isUInt<16>(Disc) && "Constant discriminator is too wide");
1593+
1594+
const MCSymbol *GASym = TM.getSymbol(GAOp.getGlobal());
1595+
uint64_t Offset = GAOp.getOffset();
1596+
1597+
// Emit instruction sequence like the following:
1598+
// ADRP x16, symbol$auth_ptr$key$disc
1599+
// LDR x16, [x16, :lo12:symbol$auth_ptr$key$disc]
1600+
//
1601+
// Where the $auth_ptr$ symbol is the stub slot containing the signed pointer
1602+
// to symbol.
1603+
assert(TM.getTargetTriple().isOSBinFormatELF() &&
1604+
"LOADauthptrstatic only implemented on ELF");
1605+
assert(Offset == 0 &&
1606+
"Non-zero offset for $auth_ptr$ stub slots is not supported");
1607+
1608+
const auto &TLOF =
1609+
static_cast<const AArch64_ELFTargetObjectFile &>(getObjFileLowering());
1610+
MCSymbol *AuthPtrStubSym =
1611+
TLOF.getAuthPtrSlotSymbol(TM, &MF->getMMI(), GASym, Key, Disc);
1612+
1613+
MachineOperand StubMOHi =
1614+
MachineOperand::CreateMCSymbol(AuthPtrStubSym, AArch64II::MO_PAGE);
1615+
MachineOperand StubMOLo = MachineOperand::CreateMCSymbol(
1616+
AuthPtrStubSym, AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
1617+
MCOperand StubMCHi, StubMCLo;
1618+
1619+
MCInstLowering.lowerOperand(StubMOHi, StubMCHi);
1620+
MCInstLowering.lowerOperand(StubMOLo, StubMCLo);
1621+
1622+
EmitToStreamer(
1623+
*OutStreamer,
1624+
MCInstBuilder(AArch64::ADRP).addReg(DstReg).addOperand(StubMCHi));
1625+
1626+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRXui)
1627+
.addReg(DstReg)
1628+
.addReg(DstReg)
1629+
.addOperand(StubMCLo));
1630+
}
1631+
1632+
void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
1633+
unsigned InstsEmitted = 0;
1634+
1635+
const bool IsGOTLoad = MI.getOpcode() == AArch64::LOADgotPAC;
1636+
MachineOperand GAOp = MI.getOperand(0);
1637+
uint64_t KeyC = MI.getOperand(1).getImm();
1638+
assert(KeyC <= AArch64PACKey::LAST && "Key is out of range");
1639+
auto Key = (AArch64PACKey::ID)KeyC;
1640+
unsigned AddrDisc = MI.getOperand(2).getReg();
1641+
uint64_t Disc = MI.getOperand(3).getImm();
1642+
assert(isUInt<16>(Disc) && "Constant discriminator is too wide");
1643+
1644+
uint64_t Offset = GAOp.getOffset();
1645+
GAOp.setOffset(0);
1646+
1647+
// Emit:
1648+
// target materialization:
1649+
// via GOT:
1650+
// adrp x16, :got:target
1651+
// ldr x16, [x16, :got_lo12:target]
1652+
// add x16, x16, #<offset> ; if offset != 0; up to 3 depending on width
1653+
//
1654+
// direct:
1655+
// adrp x16, target
1656+
// add x16, x16, :lo12:target
1657+
// add x16, x16, #<offset> ; if offset != 0; up to 3 depending on width
1658+
//
1659+
// signing:
1660+
// - 0 discriminator:
1661+
// paciza x16
1662+
// - Non-0 discriminator, no address discriminator:
1663+
// mov x17, #Disc
1664+
// pacia x16, x17
1665+
// - address discriminator (with potentially folded immediate discriminator):
1666+
// pacia x16, xAddrDisc
1667+
1668+
MachineOperand GAMOHi(GAOp), GAMOLo(GAOp);
1669+
MCOperand GAMCHi, GAMCLo;
1670+
1671+
GAMOHi.setTargetFlags(AArch64II::MO_PAGE);
1672+
GAMOLo.setTargetFlags(AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
1673+
if (IsGOTLoad) {
1674+
GAMOHi.addTargetFlag(AArch64II::MO_GOT);
1675+
GAMOLo.addTargetFlag(AArch64II::MO_GOT);
1676+
}
1677+
1678+
MCInstLowering.lowerOperand(GAMOHi, GAMCHi);
1679+
MCInstLowering.lowerOperand(GAMOLo, GAMCLo);
1680+
1681+
EmitToStreamer(
1682+
*OutStreamer,
1683+
MCInstBuilder(AArch64::ADRP).addReg(AArch64::X16).addOperand(GAMCHi));
1684+
++InstsEmitted;
1685+
1686+
if (IsGOTLoad) {
1687+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRXui)
1688+
.addReg(AArch64::X16)
1689+
.addReg(AArch64::X16)
1690+
.addOperand(GAMCLo));
1691+
++InstsEmitted;
1692+
} else {
1693+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
1694+
.addReg(AArch64::X16)
1695+
.addReg(AArch64::X16)
1696+
.addOperand(GAMCLo)
1697+
.addImm(0));
1698+
++InstsEmitted;
1699+
}
1700+
1701+
if (Offset) {
1702+
if (!isUInt<32>(Offset))
1703+
report_fatal_error("ptrauth global offset too large, 32bit max encoding");
1704+
1705+
for (int BitPos = 0; BitPos < 32 && (Offset >> BitPos); BitPos += 12) {
1706+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
1707+
.addReg(AArch64::X16)
1708+
.addReg(AArch64::X16)
1709+
.addImm((Offset >> BitPos) & 0xfff)
1710+
.addImm(AArch64_AM::getShifterImm(
1711+
AArch64_AM::LSL, BitPos)));
1712+
++InstsEmitted;
1713+
}
1714+
}
1715+
1716+
unsigned DiscReg = AddrDisc;
1717+
if (Disc) {
1718+
if (AddrDisc != AArch64::XZR) {
1719+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
1720+
.addReg(AArch64::X17)
1721+
.addReg(AArch64::XZR)
1722+
.addReg(AddrDisc)
1723+
.addImm(0));
1724+
++InstsEmitted;
1725+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
1726+
.addReg(AArch64::X17)
1727+
.addReg(AArch64::X17)
1728+
.addImm(Disc)
1729+
.addImm(/*shift=*/48));
1730+
++InstsEmitted;
1731+
} else {
1732+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
1733+
.addReg(AArch64::X17)
1734+
.addImm(Disc)
1735+
.addImm(/*shift=*/0));
1736+
++InstsEmitted;
1737+
}
1738+
DiscReg = AArch64::X17;
1739+
}
1740+
1741+
auto MIB = MCInstBuilder(getPACOpcodeForKey(Key, DiscReg == AArch64::XZR))
1742+
.addReg(AArch64::X16)
1743+
.addReg(AArch64::X16);
1744+
if (DiscReg != AArch64::XZR)
1745+
MIB.addReg(DiscReg);
1746+
EmitToStreamer(*OutStreamer, MIB);
1747+
++InstsEmitted;
1748+
1749+
assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
1750+
}
1751+
15781752
// Simple pseudo-instructions have their lowering (with expansion to real
15791753
// instructions) auto-generated.
15801754
#include "AArch64GenMCPseudoLowering.inc"
@@ -1710,6 +1884,15 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
17101884
return;
17111885
}
17121886

1887+
case AArch64::LOADauthptrstatic:
1888+
LowerLOADauthptrstatic(*MI);
1889+
return;
1890+
1891+
case AArch64::LOADgotPAC:
1892+
case AArch64::MOVaddrPAC:
1893+
LowerMOVaddrPAC(*MI);
1894+
return;
1895+
17131896
case AArch64::BLRA:
17141897
emitPtrauthBranch(MI);
17151898
return;

0 commit comments

Comments
 (0)