Skip to content

[AArch64][PAC] Lower direct authenticated calls to ptrauth constants. #97664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
// Try looking through a bitcast from one function type to another.
// Commonly happens with calls to objc_msgSend().
const Value *CalleeV = CB.getCalledOperand()->stripPointerCasts();

// If IRTranslator chose to drop the ptrauth info, we can turn this into
// a direct call.
if (!PAI && CB.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
CalleeV = cast<ConstantPtrAuth>(CalleeV)->getPointer();
assert(isa<Function>(CalleeV));
}

if (const Function *F = dyn_cast<Function>(CalleeV)) {
if (F->hasFnAttribute(Attribute::NonLazyBind)) {
LLT Ty = getLLTForType(*F->getType(), DL);
Expand Down
23 changes: 15 additions & 8 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2649,17 +2649,24 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}

std::optional<CallLowering::PtrAuthInfo> PAI;
if (CB.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
// Functions should never be ptrauth-called directly.
assert(!CB.getCalledFunction() && "invalid direct ptrauth call");

auto PAB = CB.getOperandBundle("ptrauth");
const Value *Key = PAB->Inputs[0];
const Value *Discriminator = PAB->Inputs[1];

Register DiscReg = getOrCreateVReg(*Discriminator);
PAI = CallLowering::PtrAuthInfo{cast<ConstantInt>(Key)->getZExtValue(),
DiscReg};
const Value *Key = Bundle->Inputs[0];
const Value *Discriminator = Bundle->Inputs[1];

// Look through ptrauth constants to try to eliminate the matching bundle
// and turn this into a direct call with no ptrauth.
// CallLowering will use the raw pointer if it doesn't find the PAI.
const auto *CalleeCPA = dyn_cast<ConstantPtrAuth>(CB.getCalledOperand());
if (!CalleeCPA || !isa<Function>(CalleeCPA->getPointer()) ||
!CalleeCPA->isKnownCompatibleWith(Key, Discriminator, *DL)) {
// If we can't make it direct, package the bundle into PAI.
Register DiscReg = getOrCreateVReg(*Discriminator);
PAI = CallLowering::PtrAuthInfo{cast<ConstantInt>(Key)->getZExtValue(),
DiscReg};
}
}

Register ConvergenceCtrlToken = 0;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9454,6 +9454,14 @@ void SelectionDAGBuilder::LowerCallSiteWithPtrAuthBundle(
assert(Discriminator->getType()->isIntegerTy(64) &&
"Invalid ptrauth discriminator");

// Look through ptrauth constants to find the raw callee.
// Do a direct unauthenticated call if we found it and everything matches.
if (const auto *CalleeCPA = dyn_cast<ConstantPtrAuth>(CalleeV))
if (CalleeCPA->isKnownCompatibleWith(Key, Discriminator,
DAG.getDataLayout()))
return LowerCallTo(CB, getValue(CalleeCPA->getPointer()), CB.isTailCall(),
CB.isMustTailCall(), EHPadBB);

// Functions should never be ptrauth-called directly.
assert(!isa<Function>(CalleeV) && "invalid direct ptrauth call");

Expand Down
132 changes: 132 additions & 0 deletions llvm/test/CodeGen/AArch64/ptrauth-call.ll
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,136 @@ define i32 @test_tailcall_ib_arg_ind(ptr %arg0, i64 %arg1) #0 {
ret i32 %tmp1
}

; Test direct calls

define i32 @test_direct_call() #0 {
; DARWIN-LABEL: test_direct_call:
; DARWIN-NEXT: stp x29, x30, [sp, #-16]!
; DARWIN-NEXT: bl _f
; DARWIN-NEXT: ldp x29, x30, [sp], #16
; DARWIN-NEXT: ret
;
; ELF-LABEL: test_direct_call:
; ELF-NEXT: str x30, [sp, #-16]!
; ELF-NEXT: bl f
; ELF-NEXT: ldr x30, [sp], #16
; ELF-NEXT: ret
%tmp0 = call i32 ptrauth(ptr @f, i32 0, i64 42)() [ "ptrauth"(i32 0, i64 42) ]
ret i32 %tmp0
}

define i32 @test_direct_tailcall(ptr %arg0) #0 {
; DARWIN-LABEL: test_direct_tailcall:
; DARWIN: b _f
;
; ELF-LABEL: test_direct_tailcall:
; ELF-NEXT: b f
%tmp0 = tail call i32 ptrauth(ptr @f, i32 0, i64 42)() [ "ptrauth"(i32 0, i64 42) ]
ret i32 %tmp0
}

define i32 @test_direct_call_mismatch() #0 {
; DARWIN-LABEL: test_direct_call_mismatch:
; DARWIN-NEXT: stp x29, x30, [sp, #-16]!
; DARWIN-NEXT: adrp x16, _f@GOTPAGE
; DARWIN-NEXT: ldr x16, [x16, _f@GOTPAGEOFF]
; DARWIN-NEXT: mov x17, #42
; DARWIN-NEXT: pacia x16, x17
; DARWIN-NEXT: mov x8, x16
; DARWIN-NEXT: mov x17, #42
; DARWIN-NEXT: blrab x8, x17
; DARWIN-NEXT: ldp x29, x30, [sp], #16
; DARWIN-NEXT: ret
;
; ELF-LABEL: test_direct_call_mismatch:
; ELF-NEXT: str x30, [sp, #-16]!
; ELF-NEXT: adrp x16, :got:f
; ELF-NEXT: ldr x16, [x16, :got_lo12:f]
; ELF-NEXT: mov x17, #42
; ELF-NEXT: pacia x16, x17
; ELF-NEXT: mov x8, x16
; ELF-NEXT: mov x17, #42
; ELF-NEXT: blrab x8, x17
; ELF-NEXT: ldr x30, [sp], #16
; ELF-NEXT: ret
%tmp0 = call i32 ptrauth(ptr @f, i32 0, i64 42)() [ "ptrauth"(i32 1, i64 42) ]
ret i32 %tmp0
}

define i32 @test_direct_call_addr() #0 {
; DARWIN-LABEL: test_direct_call_addr:
; DARWIN-NEXT: stp x29, x30, [sp, #-16]!
; DARWIN-NEXT: bl _f
; DARWIN-NEXT: ldp x29, x30, [sp], #16
; DARWIN-NEXT: ret
;
; ELF-LABEL: test_direct_call_addr:
; ELF-NEXT: str x30, [sp, #-16]!
; ELF-NEXT: bl f
; ELF-NEXT: ldr x30, [sp], #16
; ELF-NEXT: ret
%tmp0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr @f.ref.ib.0.addr)() [ "ptrauth"(i32 1, i64 ptrtoint (ptr @f.ref.ib.0.addr to i64)) ]
ret i32 %tmp0
}

define i32 @test_direct_call_addr_blend() #0 {
; DARWIN-LABEL: test_direct_call_addr_blend:
; DARWIN-NEXT: stp x29, x30, [sp, #-16]!
; DARWIN-NEXT: bl _f
; DARWIN-NEXT: ldp x29, x30, [sp], #16
; DARWIN-NEXT: ret
;
; ELF-LABEL: test_direct_call_addr_blend:
; ELF-NEXT: str x30, [sp, #-16]!
; ELF-NEXT: bl f
; ELF-NEXT: ldr x30, [sp], #16
; ELF-NEXT: ret
%tmp0 = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f.ref.ib.42.addr to i64), i64 42)
%tmp1 = call i32 ptrauth(ptr @f, i32 1, i64 42, ptr @f.ref.ib.42.addr)() [ "ptrauth"(i32 1, i64 %tmp0) ]
ret i32 %tmp1
}

define i32 @test_direct_call_addr_gep_different_index_types() #0 {
; DARWIN-LABEL: test_direct_call_addr_gep_different_index_types:
; DARWIN-NEXT: stp x29, x30, [sp, #-16]!
; DARWIN-NEXT: bl _f
; DARWIN-NEXT: ldp x29, x30, [sp], #16
; DARWIN-NEXT: ret
;
; ELF-LABEL: test_direct_call_addr_gep_different_index_types:
; ELF-NEXT: str x30, [sp, #-16]!
; ELF-NEXT: bl f
; ELF-NEXT: ldr x30, [sp], #16
; ELF-NEXT: ret
%tmp0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr getelementptr ({ ptr }, ptr @f_struct.ref.ib.0.addr, i64 0, i32 0))() [ "ptrauth"(i32 1, i64 ptrtoint (ptr getelementptr ({ ptr }, ptr @f_struct.ref.ib.0.addr, i32 0, i32 0) to i64)) ]
ret i32 %tmp0
}

define i32 @test_direct_call_addr_blend_gep_different_index_types() #0 {
; DARWIN-LABEL: test_direct_call_addr_blend_gep_different_index_types:
; DARWIN-NEXT: stp x29, x30, [sp, #-16]!
; DARWIN-NEXT: bl _f
; DARWIN-NEXT: ldp x29, x30, [sp], #16
; DARWIN-NEXT: ret
;
; ELF-LABEL: test_direct_call_addr_blend_gep_different_index_types:
; ELF-NEXT: str x30, [sp, #-16]!
; ELF-NEXT: bl f
; ELF-NEXT: ldr x30, [sp], #16
; ELF-NEXT: ret
%tmp0 = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr getelementptr ({ ptr }, ptr @f_struct.ref.ib.123.addr, i32 0, i32 0) to i64), i64 123)
%tmp1 = call i32 ptrauth(ptr @f, i32 1, i64 123, ptr getelementptr ({ ptr }, ptr @f_struct.ref.ib.123.addr, i64 0, i32 0))() [ "ptrauth"(i32 1, i64 %tmp0) ]
ret i32 %tmp1
}

@f.ref.ib.42.addr = external global ptr
@f.ref.ib.0.addr = external global ptr
@f_struct.ref.ib.0.addr = external global ptr
@f_struct.ref.ib.123.addr = external global ptr

declare void @f()

declare i64 @llvm.ptrauth.auth(i64, i32, i64)
declare i64 @llvm.ptrauth.blend(i64, i64)

attributes #0 = { nounwind }
Loading
Loading