Skip to content

Commit 6b7eb51

Browse files
committed
[AArch64][PAC] Lower jump-tables using hardened pseudo.
1 parent 507b0f6 commit 6b7eb51

File tree

5 files changed

+248
-1
lines changed

5 files changed

+248
-1
lines changed

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ class AArch64AsmPrinter : public AsmPrinter {
104104

105105
void LowerJumpTableDest(MCStreamer &OutStreamer, const MachineInstr &MI);
106106

107+
void LowerHardenedBRJumpTable(const MachineInstr &MI);
108+
107109
void LowerMOPS(MCStreamer &OutStreamer, const MachineInstr &MI);
108110

109111
void LowerSTACKMAP(MCStreamer &OutStreamer, StackMaps &SM,
@@ -1310,6 +1312,141 @@ void AArch64AsmPrinter::LowerJumpTableDest(llvm::MCStreamer &OutStreamer,
13101312
.addImm(Size == 4 ? 0 : 2));
13111313
}
13121314

1315+
void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
1316+
unsigned InstsEmitted = 0;
1317+
1318+
const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
1319+
assert(MJTI && "Can't lower jump-table dispatch without JTI");
1320+
1321+
const std::vector<MachineJumpTableEntry> &JTs = MJTI->getJumpTables();
1322+
assert(!JTs.empty() && "Invalid JT index for jump-table dispatch");
1323+
1324+
// Emit:
1325+
// mov x17, #<size of table> ; depending on table size, with MOVKs
1326+
// cmp x16, x17 ; or #imm if table size fits in 12-bit
1327+
// csel x16, x16, xzr, ls ; check for index overflow
1328+
//
1329+
// adrp x17, Ltable@PAGE ; materialize table address
1330+
// add x17, Ltable@PAGEOFF
1331+
// ldrsw x16, [x17, x16, lsl #2] ; load table entry
1332+
//
1333+
// Lanchor:
1334+
// adr x17, Lanchor ; compute target address
1335+
// add x16, x17, x16
1336+
// br x16 ; branch to target
1337+
1338+
MachineOperand JTOp = MI.getOperand(0);
1339+
1340+
unsigned JTI = JTOp.getIndex();
1341+
assert(!AArch64FI->getJumpTableEntryPCRelSymbol(JTI) &&
1342+
"unsupported compressed jump table");
1343+
1344+
const uint64_t NumTableEntries = JTs[JTI].MBBs.size();
1345+
1346+
// cmp only supports a 12-bit immediate. If we need more, materialize the
1347+
// immediate, using x17 as a scratch register.
1348+
uint64_t MaxTableEntry = NumTableEntries - 1;
1349+
if (isUInt<12>(MaxTableEntry)) {
1350+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXri)
1351+
.addReg(AArch64::XZR)
1352+
.addReg(AArch64::X16)
1353+
.addImm(MaxTableEntry)
1354+
.addImm(0));
1355+
++InstsEmitted;
1356+
} else {
1357+
EmitToStreamer(*OutStreamer,
1358+
MCInstBuilder(AArch64::MOVZXi)
1359+
.addReg(AArch64::X17)
1360+
.addImm(static_cast<uint16_t>(MaxTableEntry))
1361+
.addImm(0));
1362+
++InstsEmitted;
1363+
// It's sad that we have to manually materialize instructions, but we can't
1364+
// trivially reuse the main pseudo expansion logic.
1365+
// A MOVK sequence is easy enough to generate and handles the general case.
1366+
for (int Offset = 16; Offset < 64; Offset += 16) {
1367+
if ((MaxTableEntry >> Offset) == 0)
1368+
break;
1369+
EmitToStreamer(*OutStreamer,
1370+
MCInstBuilder(AArch64::MOVKXi)
1371+
.addReg(AArch64::X17)
1372+
.addReg(AArch64::X17)
1373+
.addImm(static_cast<uint16_t>(MaxTableEntry >> Offset))
1374+
.addImm(Offset));
1375+
++InstsEmitted;
1376+
}
1377+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
1378+
.addReg(AArch64::XZR)
1379+
.addReg(AArch64::X16)
1380+
.addReg(AArch64::X17)
1381+
.addImm(0));
1382+
++InstsEmitted;
1383+
}
1384+
1385+
// This picks entry #0 on failure.
1386+
// We might want to trap instead.
1387+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::CSELXr)
1388+
.addReg(AArch64::X16)
1389+
.addReg(AArch64::X16)
1390+
.addReg(AArch64::XZR)
1391+
.addImm(AArch64CC::LS));
1392+
++InstsEmitted;
1393+
1394+
// Prepare the @PAGE/@PAGEOFF low/high operands.
1395+
MachineOperand JTMOHi(JTOp), JTMOLo(JTOp);
1396+
MCOperand JTMCHi, JTMCLo;
1397+
1398+
JTMOHi.setTargetFlags(AArch64II::MO_PAGE);
1399+
JTMOLo.setTargetFlags(AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
1400+
1401+
MCInstLowering.lowerOperand(JTMOHi, JTMCHi);
1402+
MCInstLowering.lowerOperand(JTMOLo, JTMCLo);
1403+
1404+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADRP)
1405+
.addReg(AArch64::X17)
1406+
.addOperand(JTMCHi));
1407+
++InstsEmitted;
1408+
1409+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
1410+
.addReg(AArch64::X17)
1411+
.addReg(AArch64::X17)
1412+
.addOperand(JTMCLo)
1413+
.addImm(0));
1414+
++InstsEmitted;
1415+
1416+
1417+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRSWroX)
1418+
.addReg(AArch64::X16)
1419+
.addReg(AArch64::X17)
1420+
.addReg(AArch64::X16)
1421+
.addImm(0)
1422+
.addImm(1));
1423+
++InstsEmitted;
1424+
1425+
MCSymbol *AdrLabel = MF->getContext().createTempSymbol();
1426+
auto *AdrLabelE = MCSymbolRefExpr::create(AdrLabel, MF->getContext());
1427+
AArch64FI->setJumpTableEntryInfo(JTI, 4, AdrLabel);
1428+
1429+
OutStreamer->emitLabel(AdrLabel);
1430+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADR)
1431+
.addReg(AArch64::X17)
1432+
.addExpr(AdrLabelE));
1433+
++InstsEmitted;
1434+
1435+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXrs)
1436+
.addReg(AArch64::X16)
1437+
.addReg(AArch64::X17)
1438+
.addReg(AArch64::X16)
1439+
.addImm(0));
1440+
++InstsEmitted;
1441+
1442+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::BR)
1443+
.addReg(AArch64::X16));
1444+
++InstsEmitted;
1445+
1446+
assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
1447+
}
1448+
1449+
13131450
void AArch64AsmPrinter::LowerMOPS(llvm::MCStreamer &OutStreamer,
13141451
const llvm::MachineInstr &MI) {
13151452
unsigned Opcode = MI.getOpcode();
@@ -2177,6 +2314,10 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
21772314
LowerJumpTableDest(*OutStreamer, *MI);
21782315
return;
21792316

2317+
case AArch64::BR_JumpTable:
2318+
LowerHardenedBRJumpTable(*MI);
2319+
return;
2320+
21802321
case AArch64::FMOVH0:
21812322
case AArch64::FMOVS0:
21822323
case AArch64::FMOVD0:

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10678,6 +10678,21 @@ SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
1067810678
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
1067910679
AFI->setJumpTableEntryInfo(JTI, 4, nullptr);
1068010680

10681+
// With jump-table-hardening, we only expand the full jump table dispatch
10682+
// sequence later, to guarantee the integrity of the intermediate values.
10683+
if (DAG.getMachineFunction().getFunction()
10684+
.hasFnAttribute("jump-table-hardening") ||
10685+
Subtarget->getTargetTriple().isArm64e()) {
10686+
assert(Subtarget->isTargetMachO() &&
10687+
"hardened jump-table not yet supported on non-macho");
10688+
SDValue X16Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::X16,
10689+
Entry, SDValue());
10690+
SDNode *B = DAG.getMachineNode(AArch64::BR_JumpTable, DL, MVT::Other,
10691+
DAG.getTargetJumpTable(JTI, MVT::i32),
10692+
X16Copy.getValue(0), X16Copy.getValue(1));
10693+
return SDValue(B, 0);
10694+
}
10695+
1068110696
SDNode *Dest =
1068210697
DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT,
1068310698
Entry, DAG.getTargetJumpTable(JTI, MVT::i32));

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,32 @@ def JumpTableDest8 : Pseudo<(outs GPR64:$dst, GPR64sp:$scratch),
11431143
Sched<[]>;
11441144
}
11451145

1146+
// A hardened but more expensive version of jump-table dispatch.
1147+
// This combines the target address computation (otherwise done using the
1148+
// JumpTableDest pseudos above) with the branch itself (otherwise done using
1149+
// a plain BR) in a single non-attackable sequence.
1150+
//
1151+
// We take the final entry index as an operand to allow isel freedom. This does
1152+
// mean that the index can be attacker-controlled. To address that, we also do
1153+
// limited checking of the offset, mainly ensuring it still points within the
1154+
// jump-table array. When it doesn't, this branches to the first entry.
1155+
//
1156+
// This is intended for use in conjunction with ptrauth for other code pointers,
1157+
// to avoid signing jump-table entries and turning them into pointers.
1158+
//
1159+
// Entry index is passed in x16. Clobbers x16/x17/nzcv.
1160+
let isNotDuplicable = 1 in
1161+
def BR_JumpTable : Pseudo<(outs), (ins i32imm:$jti), []>, Sched<[]> {
1162+
let isBranch = 1;
1163+
let isTerminator = 1;
1164+
let isIndirectBranch = 1;
1165+
let isBarrier = 1;
1166+
let isNotDuplicable = 1;
1167+
let Defs = [X16,X17,NZCV];
1168+
let Uses = [X16];
1169+
let Size = 44; // 28 fixed + 16 variable, for table size materialization
1170+
}
1171+
11461172
// Space-consuming pseudo to aid testing of placement and reachability
11471173
// algorithms. Immediate operand is the number of bytes this "instruction"
11481174
// occupies; register operands can be used to enforce dependency and constrain

llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3597,10 +3597,22 @@ bool AArch64InstructionSelector::selectBrJT(MachineInstr &I,
35973597
unsigned JTI = I.getOperand(1).getIndex();
35983598
Register Index = I.getOperand(2).getReg();
35993599

3600+
MF->getInfo<AArch64FunctionInfo>()->setJumpTableEntryInfo(JTI, 4, nullptr);
3601+
if (MF->getFunction().hasFnAttribute("jump-table-hardening") ||
3602+
STI.getTargetTriple().isArm64e()) {
3603+
if (TM.getCodeModel() != CodeModel::Small)
3604+
report_fatal_error("Unsupported code-model for hardened jump-table");
3605+
3606+
MIB.buildCopy({AArch64::X16}, I.getOperand(2).getReg());
3607+
MIB.buildInstr(AArch64::BR_JumpTable)
3608+
.addJumpTableIndex(I.getOperand(1).getIndex());
3609+
I.eraseFromParent();
3610+
return true;
3611+
}
3612+
36003613
Register TargetReg = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
36013614
Register ScratchReg = MRI.createVirtualRegister(&AArch64::GPR64spRegClass);
36023615

3603-
MF->getInfo<AArch64FunctionInfo>()->setJumpTableEntryInfo(JTI, 4, nullptr);
36043616
auto JumpTableInst = MIB.buildInstr(AArch64::JumpTableDest32,
36053617
{TargetReg, ScratchReg}, {JTAddr, Index})
36063618
.addJumpTableIndex(JTI);
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 | FileCheck %s
2+
; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 -code-model=large | FileCheck %s
3+
; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 -global-isel -global-isel-abort=1 | FileCheck %s
4+
5+
; CHECK-LABEL: test_jumptable:
6+
; CHECK: mov w[[INDEX:[0-9]+]], w0
7+
; CHECK: cmp x[[INDEX]], #5
8+
; CHECK: csel [[INDEX2:x[0-9]+]], x[[INDEX]], xzr, ls
9+
; CHECK-NEXT: adrp [[JTPAGE:x[0-9]+]], LJTI0_0@PAGE
10+
; CHECK-NEXT: add x[[JT:[0-9]+]], [[JTPAGE]], LJTI0_0@PAGEOFF
11+
; CHECK-NEXT: ldrsw [[OFFSET:x[0-9]+]], [x[[JT]], [[INDEX2]], lsl #2]
12+
; CHECK-NEXT: Ltmp0:
13+
; CHECK-NEXT: adr [[TABLE:x[0-9]+]], Ltmp0
14+
; CHECK-NEXT: add [[DEST:x[0-9]+]], [[TABLE]], [[OFFSET]]
15+
; CHECK-NEXT: br [[DEST]]
16+
17+
define i32 @test_jumptable(i32 %in) "jump-table-hardening" {
18+
19+
switch i32 %in, label %def [
20+
i32 0, label %lbl1
21+
i32 1, label %lbl2
22+
i32 2, label %lbl3
23+
i32 4, label %lbl4
24+
i32 5, label %lbl5
25+
]
26+
27+
def:
28+
ret i32 0
29+
30+
lbl1:
31+
ret i32 1
32+
33+
lbl2:
34+
ret i32 2
35+
36+
lbl3:
37+
ret i32 4
38+
39+
lbl4:
40+
ret i32 8
41+
42+
lbl5:
43+
ret i32 10
44+
45+
}
46+
47+
; CHECK: LJTI0_0:
48+
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
49+
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
50+
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
51+
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
52+
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
53+
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0

0 commit comments

Comments
 (0)