Skip to content

[PAC][IR][AArch64] Add "ptrauth(...)" Constant to represent signed pointers. #85738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 28, 2024

Conversation

ahmedbougacha
Copy link
Member

@ahmedbougacha ahmedbougacha commented Mar 19, 2024

This defines a new kind of IR Constant that represents a ptrauth signed pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants used thus far in the llvm ptrauth implementations, notably those used in the Darwin and ELF ABIs being implemented for c/c++. These signed pointer constants are then lowered to ELF/MachO relocations.

These can be simply thought of as a constant llvm.ptrauth.sign, with the interesting addition of discriminator computation: the ptrauth constant can also represent a combined blend, when both address and integer discriminator operands are used.

This also teaches some of the most common constant folding and analysis paths about these, usually in a straightforward way. I have a couple fixmes to expand some of those here, as well as to add unittests for the ConstantPtrAuth methods.

@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Ahmed Bougacha (ahmedbougacha)

Changes

This defines a new kind of IR Constant that represents a ptrauth signed pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants used thus far in the llvm ptrauth implementations, notably those used in the Darwin and ELF ABIs being implemented for c/c++. These signed pointer constants are then lowered to ELF/MachO relocations.

These can be simply thought of as a constant llvm.ptrauth.sign, with the interesting addition of discriminator computation: the ptrauth constant can also represent a combined blend, when both address and integer discriminator operands are used.

This also teaches some of the most common constant folding and analysis paths about these, usually in a straightforward way. I have a couple fixmes to expand some of those here, as well as to add unittest for the ConstantPtrAuth methods.


Patch is 29.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85738.diff

26 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+27)
  • (modified) llvm/include/llvm-c/Core.h (+1)
  • (modified) llvm/include/llvm/AsmParser/LLToken.h (+1)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+2)
  • (modified) llvm/include/llvm/IR/Constants.h (+63)
  • (modified) llvm/include/llvm/IR/Value.def (+1)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+7)
  • (modified) llvm/lib/AsmParser/LLLexer.cpp (+1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+42)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+29-1)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+8)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+15)
  • (modified) llvm/lib/IR/ConstantFold.cpp (+3)
  • (modified) llvm/lib/IR/Constants.cpp (+115)
  • (modified) llvm/lib/IR/ConstantsContext.h (+47)
  • (modified) llvm/lib/IR/LLVMContextImpl.h (+2)
  • (modified) llvm/lib/IR/Verifier.cpp (+20)
  • (added) llvm/test/Assembler/invalid-ptrauth-const1.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const2.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const3.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const4.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const5.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const6.ll (+6)
  • (added) llvm/test/Assembler/ptrauth-const.ll (+13)
  • (modified) llvm/test/Bitcode/compatibility.ll (+4)
  • (modified) llvm/utils/vim/syntax/llvm.vim (+1)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e07b642285b3e6..0d91d4fc3ba1ef 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4748,6 +4748,33 @@ reference to the CFI jump table in the ``LowerTypeTests`` pass. These constants
 may be useful in low-level programs, such as operating system kernels, which
 need to refer to the actual function body.
 
+.. _ptrauth
+
+Authenticated Pointers
+----------------------
+
+``ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i16 DISC)
+
+A '``ptrauth``' constant represents a pointer with a cryptographic
+authentication signature embedded into some bits. Its type is the same as the
+first argument.
+
+
+If the address disciminator is ``null`` then the expression is equivalent to
+
+.. code-block:llvm
+    %tmp = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 DISC)
+    %val = inttoptr i64 %tmp to ptr
+
+If the address discriminator is present, then it is
+
+.. code-block:llvm
+    %tmp1 = call i64 @llvm.ptrauth.blend.i64(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
+    %tmp2 = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i64  %tmp1)
+    %val = inttoptr i64 %tmp2 to ptr
+
+    %tmp = call i64 @llvm.ptrauth.blend.i64
+
 .. _constantexprs:
 
 Constant Expressions
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index f56a6c961aad74..5f69a07fbed644 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -286,6 +286,7 @@ typedef enum {
   LLVMInstructionValueKind,
   LLVMPoisonValueValueKind,
   LLVMConstantTargetNoneValueKind,
+  LLVMConstantPtrAuthValueKind,
 } LLVMValueKind;
 
 typedef enum {
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 5863a8d6e8ee84..e949023463f54d 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -343,6 +343,7 @@ enum Kind {
   kw_insertvalue,
   kw_blockaddress,
   kw_dso_local_equivalent,
+  kw_ptrauth,
   kw_no_cfi,
 
   kw_freeze,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 39303e64852141..747bd55c2a8c82 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -411,6 +411,8 @@ enum ConstantsCodes {
                               //                 sideeffect|alignstack|
                               //                 asmdialect|unwind,
                               //                 asmstr,conststr]
+  CST_CODE_SIGNED_PTR = 31,   // CE_SIGNED_PTR: [ptrty, ptr, key,
+                              //                 addrdiscty, addrdisc, disc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index c0ac9a4aa6750c..9cf53616cc921b 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1006,6 +1006,69 @@ struct OperandTraits<NoCFIValue> : public FixedNumOperandTraits<NoCFIValue, 1> {
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(NoCFIValue, Value)
 
+/// A signed pointer
+///
+class ConstantPtrAuth final : public Constant {
+  friend struct ConstantPtrAuthKeyType;
+  friend class Constant;
+
+  ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, Constant *AddrDisc,
+                  ConstantInt *Disc);
+
+  void *operator new(size_t s) { return User::operator new(s, 4); }
+
+  void destroyConstantImpl();
+  Value *handleOperandChangeImpl(Value *From, Value *To);
+
+public:
+  /// Return a pointer authenticated with the specified parameters.
+  static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
+                              Constant *AddrDisc, ConstantInt *Disc);
+
+  /// Produce a new ptrauth expression signing the given value using
+  /// the same schema as is stored in one.
+  ConstantPtrAuth *getWithSameSchema(Constant *Pointer) const;
+
+  /// Transparently provide more efficient getOperand methods.
+  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Constant);
+
+  /// The pointer that is authenticated in this authenticated global reference.
+  Constant *getPointer() const { return (Constant *)Op<0>().get(); }
+
+  /// The Key ID, an i32 constant.
+  ConstantInt *getKey() const { return (ConstantInt *)Op<1>().get(); }
+
+  /// The address discriminator if any, or the null constant.
+  /// If present, this must be a value equivalent to the storage location of
+  /// the only user of the authenticated ptrauth global.
+  Constant *getAddrDiscriminator() const { return (Constant *)Op<2>().get(); }
+
+  /// The discriminator.
+  ConstantInt *getDiscriminator() const { return (ConstantInt *)Op<3>().get(); }
+
+  /// Whether there is any non-null address discriminator.
+  bool hasAddressDiversity() const {
+    return !getAddrDiscriminator()->isNullValue();
+  }
+
+  /// Check whether an authentication operation with key \p KeyV and (possibly
+  /// blended) discriminator \p DiscriminatorV is compatible with this
+  /// authenticated global reference.
+  bool isCompatibleWith(const Value *Key, const Value *Discriminator,
+                        const DataLayout &DL) const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Value *V) {
+    return V->getValueID() == ConstantPtrAuthVal;
+  }
+};
+
+template <>
+struct OperandTraits<ConstantPtrAuth>
+    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+
+DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
+
 //===----------------------------------------------------------------------===//
 /// A constant value that is initialized with an expression using
 /// other constant values.
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 61f7a87666d094..31110ff05ae368 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -78,6 +78,7 @@ HANDLE_GLOBAL_VALUE(GlobalAlias)
 HANDLE_GLOBAL_VALUE(GlobalIFunc)
 HANDLE_GLOBAL_VALUE(GlobalVariable)
 HANDLE_CONSTANT(BlockAddress)
+HANDLE_CONSTANT(ConstantPtrAuth)
 HANDLE_CONSTANT(ConstantExpr)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(DSOLocalEquivalent)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(NoCFIValue)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index fe5d084b55bbe3..d0bdaca57e47f5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2900,6 +2900,10 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
       return true;
     }
 
+    // Constant ptrauth can be null, iff the base pointer can be.
+    if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
+      return isKnownNonZero(CPA->getPointer(), DemandedElts, Depth, Q);
+
     // A global variable in address space 0 is non null unless extern weak
     // or an absolute symbol reference. Other address spaces may have null as a
     // valid address for a global, so we can't assume anything.
@@ -6993,6 +6997,9 @@ static bool isGuaranteedNotToBeUndefOrPoison(
         isa<ConstantPointerNull>(C) || isa<Function>(C))
       return true;
 
+    if (isa<ConstantPtrAuth>(C))
+      return true;
+
     if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
       return (!includesUndef(Kind) ? !C->containsPoisonElement()
                                    : !C->containsUndefOrPoisonElement()) &&
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 02f64fcfac4f0c..e37ee0bb90a82d 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -708,6 +708,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(blockaddress);
   KEYWORD(dso_local_equivalent);
   KEYWORD(no_cfi);
+  KEYWORD(ptrauth);
 
   // Metadata types.
   KEYWORD(distinct);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2e0f5ba82220c9..21039f7efb9b27 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3998,6 +3998,48 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     ID.NoCFI = true;
     return false;
   }
+  case lltok::kw_ptrauth: {
+    // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key> ','
+    //                         ptr addrdisc ',' i64 <disc> ')'
+    Lex.Lex();
+
+    Constant *Ptr, *Key, *AddrDisc, *Disc;
+
+    if (parseToken(lltok::lparen,
+                   "expected '(' in signed pointer expression") ||
+        parseGlobalTypeAndValue(Ptr) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Key) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(AddrDisc) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Disc) ||
+        parseToken(lltok::rparen, "expected ')' in signed pointer expression"))
+      return true;
+
+    if (!Ptr->getType()->isPointerTy())
+      return error(ID.Loc, "signed pointer must be a pointer");
+
+    auto KeyC = dyn_cast<ConstantInt>(Key);
+    if (!KeyC || KeyC->getBitWidth() != 32)
+      return error(ID.Loc, "signed pointer key must be i32 constant integer");
+
+    if (!AddrDisc->getType()->isPointerTy())
+      return error(ID.Loc,
+                   "signed pointer address discriminator must be a pointer");
+
+    auto DiscC = dyn_cast<ConstantInt>(Disc);
+    if (!DiscC || DiscC->getBitWidth() != 64)
+      return error(ID.Loc,
+                   "signed pointer discriminator must be i64 constant integer");
+
+    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, AddrDisc, DiscC);
+    ID.Kind = ValID::t_Constant;
+    return false;
+  }
 
   case lltok::kw_trunc:
   case lltok::kw_bitcast:
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index d284c9823c9ede..538200abcf6f97 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -504,7 +504,8 @@ class BitcodeConstant final : public Value,
   static constexpr uint8_t NoCFIOpcode = 252;
   static constexpr uint8_t DSOLocalEquivalentOpcode = 251;
   static constexpr uint8_t BlockAddressOpcode = 250;
-  static constexpr uint8_t FirstSpecialOpcode = BlockAddressOpcode;
+  static constexpr uint8_t ConstantPtrAuthOpcode = 249;
+  static constexpr uint8_t FirstSpecialOpcode = ConstantPtrAuthOpcode;
 
   // Separate struct to make passing different number of parameters to
   // BitcodeConstant::create() more convenient.
@@ -1528,6 +1529,18 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
         C = ConstantExpr::get(BC->Opcode, ConstOps[0], ConstOps[1], BC->Flags);
       } else {
         switch (BC->Opcode) {
+        case BitcodeConstant::ConstantPtrAuthOpcode: {
+          auto *Key = dyn_cast<ConstantInt>(ConstOps[1]);
+          if (!Key)
+            return error("ptrauth key operand must be ConstantInt");
+
+          auto *Disc = dyn_cast<ConstantInt>(ConstOps[3]);
+          if (!Disc)
+            return error("ptrauth disc operand must be ConstantInt");
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, ConstOps[2], Disc);
+          break;
+        }
         case BitcodeConstant::NoCFIOpcode: {
           auto *GV = dyn_cast<GlobalValue>(ConstOps[0]);
           if (!GV)
@@ -3596,6 +3609,21 @@ Error BitcodeReader::parseConstants() {
                                   Record[1]);
       break;
     }
+    case bitc::CST_CODE_SIGNED_PTR: {
+      if (Record.size() < 6)
+        return error("Invalid record");
+      Type *PtrTy = getTypeByID(Record[0]);
+      if (!PtrTy)
+        return error("Invalid record");
+
+      // PtrTy, Ptr, Key, AddrDiscTy, AddrDisc, Disc
+      V = BitcodeConstant::create(
+        Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+        {(unsigned)Record[1], (unsigned)Record[2], (unsigned)Record[4],
+         (unsigned)Record[5]});
+
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6f0879a4e0ee74..74f1bd8ba49b57 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2800,6 +2800,14 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Record.push_back(VE.getTypeID(BA->getFunction()->getType()));
       Record.push_back(VE.getValueID(BA->getFunction()));
       Record.push_back(VE.getGlobalBasicBlockID(BA->getBasicBlock()));
+    } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(C)) {
+      Code = bitc::CST_CODE_SIGNED_PTR;
+      Record.push_back(VE.getTypeID(SP->getPointer()->getType()));
+      Record.push_back(VE.getValueID(SP->getPointer()));
+      Record.push_back(VE.getValueID(SP->getKey()));
+      Record.push_back(VE.getTypeID(SP->getAddrDiscriminator()->getType()));
+      Record.push_back(VE.getValueID(SP->getAddrDiscriminator()));
+      Record.push_back(VE.getValueID(SP->getDiscriminator()));
     } else if (const auto *Equiv = dyn_cast<DSOLocalEquivalent>(C)) {
       Code = bitc::CST_CODE_DSO_LOCAL_EQUIVALENT;
       Record.push_back(VE.getTypeID(Equiv->getGlobalValue()->getType()));
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 19acc89f73fb7e..0e9227f0945a4d 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1578,6 +1578,21 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
     return;
   }
 
+  if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(CV)) {
+    Out << "ptrauth (";
+
+    for (unsigned i = 0; i < SP->getNumOperands(); ++i) {
+      WriterCtx.TypePrinter->print(SP->getOperand(i)->getType(), Out);
+      Out << ' ';
+      WriteAsOperandInternal(Out, SP->getOperand(i), WriterCtx);
+      if (i != SP->getNumOperands() - 1)
+        Out << ", ";
+    }
+
+    Out << ')';
+    return;
+  }
+
   if (const ConstantArray *CA = dyn_cast<ConstantArray>(CV)) {
     Type *ETy = CA->getType()->getElementType();
     Out << '[';
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 034e397bc69fce..0bfec86783378a 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1154,6 +1154,9 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2) {
                                 GV->getType()->getAddressSpace()))
         return ICmpInst::ICMP_UGT;
     }
+  } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(V1)) {
+    // FIXME: ahmedbougacha: implement ptrauth cst comparison
+    return ICmpInst::BAD_ICMP_PREDICATE;
   } else {
     // Ok, the LHS is known to be a constantexpr.  The RHS can be any of a
     // constantexpr, a global, block address, or a simple constant.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index e6b92aad392f66..1af52f9e612c29 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -550,6 +550,9 @@ void llvm::deleteConstant(Constant *C) {
   case Constant::NoCFIValueVal:
     delete static_cast<NoCFIValue *>(C);
     break;
+  case Constant::ConstantPtrAuthVal:
+    delete static_cast<ConstantPtrAuth *>(C);
+    break;
   case Constant::UndefValueVal:
     delete static_cast<UndefValue *>(C);
     break;
@@ -2015,6 +2018,118 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
   return nullptr;
 }
 
+//---- ConstantPtrAuth::get() implementations.
+//
+
+static bool areEquivalentAddrDiscriminators(const Value *V1, const Value *V2,
+                                            const DataLayout &DL) {
+  APInt V1Off(DL.getPointerSizeInBits(), 0);
+  APInt V2Off(DL.getPointerSizeInBits(), 0);
+
+  if (auto *V1Cast = dyn_cast<PtrToIntOperator>(V1))
+    V1 = V1Cast->getPointerOperand();
+  if (auto *V2Cast = dyn_cast<PtrToIntOperator>(V2))
+    V2 = V2Cast->getPointerOperand();
+  auto *V1Base = V1->stripAndAccumulateConstantOffsets(
+      DL, V1Off, /*AllowNonInbounds=*/true);
+  auto *V2Base = V2->stripAndAccumulateConstantOffsets(
+      DL, V2Off, /*AllowNonInbounds=*/true);
+  return V1Base == V2Base && V1Off == V2Off;
+}
+
+bool ConstantPtrAuth::isCompatibleWith(const Value *Key,
+                                       const Value *Discriminator,
+                                       const DataLayout &DL) const {
+  // If the keys are different, there's no chance for this to be compatible.
+  if (Key != getKey())
+    return false;
+
+  // If the discriminators are the same, this is compatible iff there is no
+  // address discriminator.
+  if (Discriminator == getDiscriminator())
+    return getAddrDiscriminator()->isNullValue();
+
+  // If we dynamically blend the discriminator with the address discriminator,
+  // this is compatible.
+  if (auto *DiscBlend = dyn_cast<IntrinsicInst>(Discriminator)) {
+    if (DiscBlend->getIntrinsicID() == Intrinsic::ptrauth_blend &&
+        DiscBlend->getOperand(1) == getDiscriminator() &&
+        areEquivalentAddrDiscriminators(DiscBlend->getOperand(0),
+                                        getAddrDiscriminator(), DL))
+      return true;
+  }
+
+  // If we don't have a non-address discriminator, we don't need a blend in
+  // the first place:  accept the address discriminator as the discriminator.
+  if (getDiscriminator()->isNullValue() &&
+      areEquivalentAddrDiscriminators(getAddrDiscriminator(), Discriminator,
+                                      DL))
+    return true;
+
+  // Otherwise, we don't know.
+  return false;
+}
+
+ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
+  return get(Pointer, getKey(), getAddrDiscriminator(), getDiscriminator());
+}
+
+ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
+                                      Constant *AddrDisc, ConstantInt *Disc) {
+  Constant *ArgVec[] = {Ptr, Key, AddrDisc, Disc};
+  ConstantPtrAuthKeyType MapKey(ArgVec);
+  LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
+  return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
+}
+
+ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
+                                 Constant *AddrDisc, ConstantInt *Disc)
+    : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, &Op<0>(), 4) {
+#ifndef NDEBUG
+  assert(Ptr->getType()->isPointerTy());
+  assert(Key->getBitWidth() == 32);
+  assert(AddrDisc->getType()->isPointerTy());
+  assert(Disc->getBitWidth() == 64);
+#endif
+  setOperand(0, Ptr);
+  setOperand(1, Key);
+  setOperand(2, AddrDisc);
+  setOperand(3, Disc);
+}
+
+/// Remove the constant from the constant table.
+void ConstantPtrAuth::destroyConstantImpl() {
+  getType()->getContext().pImpl->ConstantPtrAuths.remove(this);
+}
+
+Value *ConstantPtrAuth::handleOperandChangeImpl(Value *From, Value *ToV) {
+  assert(isa<Constant>(ToV) && "Cannot make Constant refer to non-constant!");
+  Constant *To = cast<Constant>(ToV);
+
+  SmallVector<Constant *, 8> Values;
+  Values.reserve(getNumOperands()); // Build replacement array.
+
+  // Fill values with the modified operands of the constant array.  Also,
+  // compute whether this turns into an all-zeros array.
+  unsigned NumUpdated = 0;
+
+  Use *OperandList = getOperandList();
+  unsigned OperandNo = 0;
+  for (Use *O = OperandList, *E = OperandList + getNumOperands(); O != E; ++O) {
+    Constant *Val = cast<Constant>(O->get());
+    if (Val == From) {
+      OperandNo = (O - OperandList);
+      Val = To;
+      ++NumUpdated;
+    }
+    Values.push_back(Val);
+  }
+
+  // FIXME: shouldn't we check it's not already there?
+  return getContext().pImpl->ConstantPtrAuths.replaceOperandsInPlace(
+      Values, this, From, To, NumUpdated, OperandNo);
+}
+
 //---- ConstantExpr::get() implementations.
 //
 
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index 44a926b5dc58e0..bd111f406687c1 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -23,6 +23,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -282,6 +283...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-llvm-ir

Author: Ahmed Bougacha (ahmedbougacha)

Changes

This defines a new kind of IR Constant that represents a ptrauth signed pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants used thus far in the llvm ptrauth implementations, notably those used in the Darwin and ELF ABIs being implemented for c/c++. These signed pointer constants are then lowered to ELF/MachO relocations.

These can be simply thought of as a constant llvm.ptrauth.sign, with the interesting addition of discriminator computation: the ptrauth constant can also represent a combined blend, when both address and integer discriminator operands are used.

This also teaches some of the most common constant folding and analysis paths about these, usually in a straightforward way. I have a couple fixmes to expand some of those here, as well as to add unittest for the ConstantPtrAuth methods.


Patch is 29.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85738.diff

26 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+27)
  • (modified) llvm/include/llvm-c/Core.h (+1)
  • (modified) llvm/include/llvm/AsmParser/LLToken.h (+1)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+2)
  • (modified) llvm/include/llvm/IR/Constants.h (+63)
  • (modified) llvm/include/llvm/IR/Value.def (+1)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+7)
  • (modified) llvm/lib/AsmParser/LLLexer.cpp (+1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+42)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+29-1)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+8)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+15)
  • (modified) llvm/lib/IR/ConstantFold.cpp (+3)
  • (modified) llvm/lib/IR/Constants.cpp (+115)
  • (modified) llvm/lib/IR/ConstantsContext.h (+47)
  • (modified) llvm/lib/IR/LLVMContextImpl.h (+2)
  • (modified) llvm/lib/IR/Verifier.cpp (+20)
  • (added) llvm/test/Assembler/invalid-ptrauth-const1.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const2.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const3.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const4.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const5.ll (+6)
  • (added) llvm/test/Assembler/invalid-ptrauth-const6.ll (+6)
  • (added) llvm/test/Assembler/ptrauth-const.ll (+13)
  • (modified) llvm/test/Bitcode/compatibility.ll (+4)
  • (modified) llvm/utils/vim/syntax/llvm.vim (+1)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e07b642285b3e6..0d91d4fc3ba1ef 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4748,6 +4748,33 @@ reference to the CFI jump table in the ``LowerTypeTests`` pass. These constants
 may be useful in low-level programs, such as operating system kernels, which
 need to refer to the actual function body.
 
+.. _ptrauth
+
+Authenticated Pointers
+----------------------
+
+``ptrauth (ptr CST, i32 KEY, ptr ADDRDISC, i16 DISC)
+
+A '``ptrauth``' constant represents a pointer with a cryptographic
+authentication signature embedded into some bits. Its type is the same as the
+first argument.
+
+
+If the address disciminator is ``null`` then the expression is equivalent to
+
+.. code-block:llvm
+    %tmp = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 DISC)
+    %val = inttoptr i64 %tmp to ptr
+
+If the address discriminator is present, then it is
+
+.. code-block:llvm
+    %tmp1 = call i64 @llvm.ptrauth.blend.i64(i64 ptrtoint (ptr ADDRDISC to i64), i64 DISC)
+    %tmp2 = call i64 @llvm.ptrauth.sign.i64(i64 ptrtoint (ptr CST to i64), i64  %tmp1)
+    %val = inttoptr i64 %tmp2 to ptr
+
+    %tmp = call i64 @llvm.ptrauth.blend.i64
+
 .. _constantexprs:
 
 Constant Expressions
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index f56a6c961aad74..5f69a07fbed644 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -286,6 +286,7 @@ typedef enum {
   LLVMInstructionValueKind,
   LLVMPoisonValueValueKind,
   LLVMConstantTargetNoneValueKind,
+  LLVMConstantPtrAuthValueKind,
 } LLVMValueKind;
 
 typedef enum {
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 5863a8d6e8ee84..e949023463f54d 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -343,6 +343,7 @@ enum Kind {
   kw_insertvalue,
   kw_blockaddress,
   kw_dso_local_equivalent,
+  kw_ptrauth,
   kw_no_cfi,
 
   kw_freeze,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 39303e64852141..747bd55c2a8c82 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -411,6 +411,8 @@ enum ConstantsCodes {
                               //                 sideeffect|alignstack|
                               //                 asmdialect|unwind,
                               //                 asmstr,conststr]
+  CST_CODE_SIGNED_PTR = 31,   // CE_SIGNED_PTR: [ptrty, ptr, key,
+                              //                 addrdiscty, addrdisc, disc]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index c0ac9a4aa6750c..9cf53616cc921b 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1006,6 +1006,69 @@ struct OperandTraits<NoCFIValue> : public FixedNumOperandTraits<NoCFIValue, 1> {
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(NoCFIValue, Value)
 
+/// A signed pointer
+///
+class ConstantPtrAuth final : public Constant {
+  friend struct ConstantPtrAuthKeyType;
+  friend class Constant;
+
+  ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, Constant *AddrDisc,
+                  ConstantInt *Disc);
+
+  void *operator new(size_t s) { return User::operator new(s, 4); }
+
+  void destroyConstantImpl();
+  Value *handleOperandChangeImpl(Value *From, Value *To);
+
+public:
+  /// Return a pointer authenticated with the specified parameters.
+  static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
+                              Constant *AddrDisc, ConstantInt *Disc);
+
+  /// Produce a new ptrauth expression signing the given value using
+  /// the same schema as is stored in one.
+  ConstantPtrAuth *getWithSameSchema(Constant *Pointer) const;
+
+  /// Transparently provide more efficient getOperand methods.
+  DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Constant);
+
+  /// The pointer that is authenticated in this authenticated global reference.
+  Constant *getPointer() const { return (Constant *)Op<0>().get(); }
+
+  /// The Key ID, an i32 constant.
+  ConstantInt *getKey() const { return (ConstantInt *)Op<1>().get(); }
+
+  /// The address discriminator if any, or the null constant.
+  /// If present, this must be a value equivalent to the storage location of
+  /// the only user of the authenticated ptrauth global.
+  Constant *getAddrDiscriminator() const { return (Constant *)Op<2>().get(); }
+
+  /// The discriminator.
+  ConstantInt *getDiscriminator() const { return (ConstantInt *)Op<3>().get(); }
+
+  /// Whether there is any non-null address discriminator.
+  bool hasAddressDiversity() const {
+    return !getAddrDiscriminator()->isNullValue();
+  }
+
+  /// Check whether an authentication operation with key \p KeyV and (possibly
+  /// blended) discriminator \p DiscriminatorV is compatible with this
+  /// authenticated global reference.
+  bool isCompatibleWith(const Value *Key, const Value *Discriminator,
+                        const DataLayout &DL) const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Value *V) {
+    return V->getValueID() == ConstantPtrAuthVal;
+  }
+};
+
+template <>
+struct OperandTraits<ConstantPtrAuth>
+    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+
+DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
+
 //===----------------------------------------------------------------------===//
 /// A constant value that is initialized with an expression using
 /// other constant values.
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 61f7a87666d094..31110ff05ae368 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -78,6 +78,7 @@ HANDLE_GLOBAL_VALUE(GlobalAlias)
 HANDLE_GLOBAL_VALUE(GlobalIFunc)
 HANDLE_GLOBAL_VALUE(GlobalVariable)
 HANDLE_CONSTANT(BlockAddress)
+HANDLE_CONSTANT(ConstantPtrAuth)
 HANDLE_CONSTANT(ConstantExpr)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(DSOLocalEquivalent)
 HANDLE_CONSTANT_EXCLUDE_LLVM_C_API(NoCFIValue)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index fe5d084b55bbe3..d0bdaca57e47f5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2900,6 +2900,10 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
       return true;
     }
 
+    // Constant ptrauth can be null, iff the base pointer can be.
+    if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
+      return isKnownNonZero(CPA->getPointer(), DemandedElts, Depth, Q);
+
     // A global variable in address space 0 is non null unless extern weak
     // or an absolute symbol reference. Other address spaces may have null as a
     // valid address for a global, so we can't assume anything.
@@ -6993,6 +6997,9 @@ static bool isGuaranteedNotToBeUndefOrPoison(
         isa<ConstantPointerNull>(C) || isa<Function>(C))
       return true;
 
+    if (isa<ConstantPtrAuth>(C))
+      return true;
+
     if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
       return (!includesUndef(Kind) ? !C->containsPoisonElement()
                                    : !C->containsUndefOrPoisonElement()) &&
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 02f64fcfac4f0c..e37ee0bb90a82d 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -708,6 +708,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(blockaddress);
   KEYWORD(dso_local_equivalent);
   KEYWORD(no_cfi);
+  KEYWORD(ptrauth);
 
   // Metadata types.
   KEYWORD(distinct);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2e0f5ba82220c9..21039f7efb9b27 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3998,6 +3998,48 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     ID.NoCFI = true;
     return false;
   }
+  case lltok::kw_ptrauth: {
+    // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key> ','
+    //                         ptr addrdisc ',' i64 <disc> ')'
+    Lex.Lex();
+
+    Constant *Ptr, *Key, *AddrDisc, *Disc;
+
+    if (parseToken(lltok::lparen,
+                   "expected '(' in signed pointer expression") ||
+        parseGlobalTypeAndValue(Ptr) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Key) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(AddrDisc) ||
+        parseToken(lltok::comma,
+                   "expected comma in signed pointer expression") ||
+        parseGlobalTypeAndValue(Disc) ||
+        parseToken(lltok::rparen, "expected ')' in signed pointer expression"))
+      return true;
+
+    if (!Ptr->getType()->isPointerTy())
+      return error(ID.Loc, "signed pointer must be a pointer");
+
+    auto KeyC = dyn_cast<ConstantInt>(Key);
+    if (!KeyC || KeyC->getBitWidth() != 32)
+      return error(ID.Loc, "signed pointer key must be i32 constant integer");
+
+    if (!AddrDisc->getType()->isPointerTy())
+      return error(ID.Loc,
+                   "signed pointer address discriminator must be a pointer");
+
+    auto DiscC = dyn_cast<ConstantInt>(Disc);
+    if (!DiscC || DiscC->getBitWidth() != 64)
+      return error(ID.Loc,
+                   "signed pointer discriminator must be i64 constant integer");
+
+    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, AddrDisc, DiscC);
+    ID.Kind = ValID::t_Constant;
+    return false;
+  }
 
   case lltok::kw_trunc:
   case lltok::kw_bitcast:
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index d284c9823c9ede..538200abcf6f97 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -504,7 +504,8 @@ class BitcodeConstant final : public Value,
   static constexpr uint8_t NoCFIOpcode = 252;
   static constexpr uint8_t DSOLocalEquivalentOpcode = 251;
   static constexpr uint8_t BlockAddressOpcode = 250;
-  static constexpr uint8_t FirstSpecialOpcode = BlockAddressOpcode;
+  static constexpr uint8_t ConstantPtrAuthOpcode = 249;
+  static constexpr uint8_t FirstSpecialOpcode = ConstantPtrAuthOpcode;
 
   // Separate struct to make passing different number of parameters to
   // BitcodeConstant::create() more convenient.
@@ -1528,6 +1529,18 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
         C = ConstantExpr::get(BC->Opcode, ConstOps[0], ConstOps[1], BC->Flags);
       } else {
         switch (BC->Opcode) {
+        case BitcodeConstant::ConstantPtrAuthOpcode: {
+          auto *Key = dyn_cast<ConstantInt>(ConstOps[1]);
+          if (!Key)
+            return error("ptrauth key operand must be ConstantInt");
+
+          auto *Disc = dyn_cast<ConstantInt>(ConstOps[3]);
+          if (!Disc)
+            return error("ptrauth disc operand must be ConstantInt");
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, ConstOps[2], Disc);
+          break;
+        }
         case BitcodeConstant::NoCFIOpcode: {
           auto *GV = dyn_cast<GlobalValue>(ConstOps[0]);
           if (!GV)
@@ -3596,6 +3609,21 @@ Error BitcodeReader::parseConstants() {
                                   Record[1]);
       break;
     }
+    case bitc::CST_CODE_SIGNED_PTR: {
+      if (Record.size() < 6)
+        return error("Invalid record");
+      Type *PtrTy = getTypeByID(Record[0]);
+      if (!PtrTy)
+        return error("Invalid record");
+
+      // PtrTy, Ptr, Key, AddrDiscTy, AddrDisc, Disc
+      V = BitcodeConstant::create(
+        Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+        {(unsigned)Record[1], (unsigned)Record[2], (unsigned)Record[4],
+         (unsigned)Record[5]});
+
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6f0879a4e0ee74..74f1bd8ba49b57 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2800,6 +2800,14 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       Record.push_back(VE.getTypeID(BA->getFunction()->getType()));
       Record.push_back(VE.getValueID(BA->getFunction()));
       Record.push_back(VE.getGlobalBasicBlockID(BA->getBasicBlock()));
+    } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(C)) {
+      Code = bitc::CST_CODE_SIGNED_PTR;
+      Record.push_back(VE.getTypeID(SP->getPointer()->getType()));
+      Record.push_back(VE.getValueID(SP->getPointer()));
+      Record.push_back(VE.getValueID(SP->getKey()));
+      Record.push_back(VE.getTypeID(SP->getAddrDiscriminator()->getType()));
+      Record.push_back(VE.getValueID(SP->getAddrDiscriminator()));
+      Record.push_back(VE.getValueID(SP->getDiscriminator()));
     } else if (const auto *Equiv = dyn_cast<DSOLocalEquivalent>(C)) {
       Code = bitc::CST_CODE_DSO_LOCAL_EQUIVALENT;
       Record.push_back(VE.getTypeID(Equiv->getGlobalValue()->getType()));
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 19acc89f73fb7e..0e9227f0945a4d 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1578,6 +1578,21 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
     return;
   }
 
+  if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(CV)) {
+    Out << "ptrauth (";
+
+    for (unsigned i = 0; i < SP->getNumOperands(); ++i) {
+      WriterCtx.TypePrinter->print(SP->getOperand(i)->getType(), Out);
+      Out << ' ';
+      WriteAsOperandInternal(Out, SP->getOperand(i), WriterCtx);
+      if (i != SP->getNumOperands() - 1)
+        Out << ", ";
+    }
+
+    Out << ')';
+    return;
+  }
+
   if (const ConstantArray *CA = dyn_cast<ConstantArray>(CV)) {
     Type *ETy = CA->getType()->getElementType();
     Out << '[';
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 034e397bc69fce..0bfec86783378a 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1154,6 +1154,9 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2) {
                                 GV->getType()->getAddressSpace()))
         return ICmpInst::ICMP_UGT;
     }
+  } else if (const ConstantPtrAuth *SP = dyn_cast<ConstantPtrAuth>(V1)) {
+    // FIXME: ahmedbougacha: implement ptrauth cst comparison
+    return ICmpInst::BAD_ICMP_PREDICATE;
   } else {
     // Ok, the LHS is known to be a constantexpr.  The RHS can be any of a
     // constantexpr, a global, block address, or a simple constant.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index e6b92aad392f66..1af52f9e612c29 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -550,6 +550,9 @@ void llvm::deleteConstant(Constant *C) {
   case Constant::NoCFIValueVal:
     delete static_cast<NoCFIValue *>(C);
     break;
+  case Constant::ConstantPtrAuthVal:
+    delete static_cast<ConstantPtrAuth *>(C);
+    break;
   case Constant::UndefValueVal:
     delete static_cast<UndefValue *>(C);
     break;
@@ -2015,6 +2018,118 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
   return nullptr;
 }
 
+//---- ConstantPtrAuth::get() implementations.
+//
+
+static bool areEquivalentAddrDiscriminators(const Value *V1, const Value *V2,
+                                            const DataLayout &DL) {
+  APInt V1Off(DL.getPointerSizeInBits(), 0);
+  APInt V2Off(DL.getPointerSizeInBits(), 0);
+
+  if (auto *V1Cast = dyn_cast<PtrToIntOperator>(V1))
+    V1 = V1Cast->getPointerOperand();
+  if (auto *V2Cast = dyn_cast<PtrToIntOperator>(V2))
+    V2 = V2Cast->getPointerOperand();
+  auto *V1Base = V1->stripAndAccumulateConstantOffsets(
+      DL, V1Off, /*AllowNonInbounds=*/true);
+  auto *V2Base = V2->stripAndAccumulateConstantOffsets(
+      DL, V2Off, /*AllowNonInbounds=*/true);
+  return V1Base == V2Base && V1Off == V2Off;
+}
+
+bool ConstantPtrAuth::isCompatibleWith(const Value *Key,
+                                       const Value *Discriminator,
+                                       const DataLayout &DL) const {
+  // If the keys are different, there's no chance for this to be compatible.
+  if (Key != getKey())
+    return false;
+
+  // If the discriminators are the same, this is compatible iff there is no
+  // address discriminator.
+  if (Discriminator == getDiscriminator())
+    return getAddrDiscriminator()->isNullValue();
+
+  // If we dynamically blend the discriminator with the address discriminator,
+  // this is compatible.
+  if (auto *DiscBlend = dyn_cast<IntrinsicInst>(Discriminator)) {
+    if (DiscBlend->getIntrinsicID() == Intrinsic::ptrauth_blend &&
+        DiscBlend->getOperand(1) == getDiscriminator() &&
+        areEquivalentAddrDiscriminators(DiscBlend->getOperand(0),
+                                        getAddrDiscriminator(), DL))
+      return true;
+  }
+
+  // If we don't have a non-address discriminator, we don't need a blend in
+  // the first place:  accept the address discriminator as the discriminator.
+  if (getDiscriminator()->isNullValue() &&
+      areEquivalentAddrDiscriminators(getAddrDiscriminator(), Discriminator,
+                                      DL))
+    return true;
+
+  // Otherwise, we don't know.
+  return false;
+}
+
+ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
+  return get(Pointer, getKey(), getAddrDiscriminator(), getDiscriminator());
+}
+
+ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
+                                      Constant *AddrDisc, ConstantInt *Disc) {
+  Constant *ArgVec[] = {Ptr, Key, AddrDisc, Disc};
+  ConstantPtrAuthKeyType MapKey(ArgVec);
+  LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
+  return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
+}
+
+ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
+                                 Constant *AddrDisc, ConstantInt *Disc)
+    : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, &Op<0>(), 4) {
+#ifndef NDEBUG
+  assert(Ptr->getType()->isPointerTy());
+  assert(Key->getBitWidth() == 32);
+  assert(AddrDisc->getType()->isPointerTy());
+  assert(Disc->getBitWidth() == 64);
+#endif
+  setOperand(0, Ptr);
+  setOperand(1, Key);
+  setOperand(2, AddrDisc);
+  setOperand(3, Disc);
+}
+
+/// Remove the constant from the constant table.
+void ConstantPtrAuth::destroyConstantImpl() {
+  getType()->getContext().pImpl->ConstantPtrAuths.remove(this);
+}
+
+Value *ConstantPtrAuth::handleOperandChangeImpl(Value *From, Value *ToV) {
+  assert(isa<Constant>(ToV) && "Cannot make Constant refer to non-constant!");
+  Constant *To = cast<Constant>(ToV);
+
+  SmallVector<Constant *, 8> Values;
+  Values.reserve(getNumOperands()); // Build replacement array.
+
+  // Fill values with the modified operands of the constant array.  Also,
+  // compute whether this turns into an all-zeros array.
+  unsigned NumUpdated = 0;
+
+  Use *OperandList = getOperandList();
+  unsigned OperandNo = 0;
+  for (Use *O = OperandList, *E = OperandList + getNumOperands(); O != E; ++O) {
+    Constant *Val = cast<Constant>(O->get());
+    if (Val == From) {
+      OperandNo = (O - OperandList);
+      Val = To;
+      ++NumUpdated;
+    }
+    Values.push_back(Val);
+  }
+
+  // FIXME: shouldn't we check it's not already there?
+  return getContext().pImpl->ConstantPtrAuths.replaceOperandsInPlace(
+      Values, this, From, To, NumUpdated, OperandNo);
+}
+
 //---- ConstantExpr::get() implementations.
 //
 
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index 44a926b5dc58e0..bd111f406687c1 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -23,6 +23,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -282,6 +283...
[truncated]

Copy link

github-actions bot commented Mar 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not familiar with ptrauth, but this looks fine to me from an IR perspective. There should probably be an RFC on discourse for this addition though (unless it's already part of some general ptrauth RFC).

@asl asl linked an issue Apr 29, 2024 that may be closed by this pull request
@asl
Copy link
Collaborator

asl commented May 6, 2024

@ahmedbougacha How can we help to move this PR forward?

@asl asl changed the title [IR][AArch64] Add "ptrauth(...)" Constant to represent signed pointers. [PAC][IR][AArch64] Add "ptrauth(...)" Constant to represent signed pointers. May 6, 2024
@ahmedbougacha
Copy link
Member Author

ahmedbougacha commented May 8, 2024

Thanks for taking a look! e47a75a should address the comments, and 48a946c does a couple other minor cleanups. Looking at this again, we can make the discriminator components optional while it's still easy, and re-order them to have the simple integer-discriminator-only case shorter; did that in c6638c7.
We had a couple RFCs in the past but I'll make a fresh thread.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good in terms of IR, but someone familiar with ptrauth should take a look at this as well...

return error(
ID.Loc, "constant ptrauth address discriminator must be a pointer");
} else {
AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can Ptr be in a non-default address space, and if so, does AddrDisc have to be in the same address space?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see a test below where this is the case. I guess the question then changes to: Should the default address space of AddrDisc be 0 or the address space of Ptr?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the base pointer and the addr-disc pointer aren't really related, other than "addr-disc points to memory that's initialized with the full ptrauth() signed pointer." So IMO it makes sense to leave the addr-disc as address space 0 by default, though the targets where we'd have real ptrauth implementations don't have very sophisticated usage of address-spaces, so it's probably not the end of the world to forbid them entirely for now.

APInt Off1(DL.getTypeSizeInBits(V1->getType()), 0);
APInt Off2(DL.getTypeSizeInBits(V2->getType()), 0);

if (auto *V1Cast = dyn_cast<PtrToIntOperator>(V1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does ptrtoint occur here if the other side is a ptrauth.blend?

Copy link
Member Author

@ahmedbougacha ahmedbougacha May 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good observation, it's really only one side that we'd expect to be a (ptrtoint of a) blend. I rewrote the function to hopefully be more explicit about what it's actually checking.

For some context for the ptrauth crowd, an interesting user is comparing a call bundle with a constant target, and the bundles only have a single discriminator operand. I spent some time looking into splitting up the discriminator there, but I don't think it's worth it right now, and isn't all that useful without doing the same split throughout the intrinsics as well. But if and when we do that, this function becomes a trivial component-wise comparison, mostly.

But back to the patch: I also renamed this isKnownCompatibleWith because the old name was misleading.

@ahmedbougacha
Copy link
Member Author

FWIW, RFC here:
https://discourse.llvm.org/t/rfc-adding-ptrauth-constants/78926

I was waiting to also bring up the bundles there (from #85736), but on second thought those aren't really interesting enough (at the IR level, beyond backend boilerplate) to be deserving of that; let's focus on constants!

which describes an authenticated relocation producing a signed pointer.

```llvm
ptrauth (ptr CST, i32 KEY, i64 DISC, ptr ADDRDISC)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any restrictions on ptr worth documenting? Or is it just the same restrictions as constant expressions in globals, ie, that they can eventually be lowered to Mach-o/ELF relocations?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Eventually this will be lowered to ELF pauth relocs (see e.g. https://github.com/ARM-software/abi-aa/blob/main/pauthabielf64/pauthabielf64.rst)

@asl asl requested a review from nikic May 21, 2024 20:21
@efriedma-quic
Copy link
Collaborator

I'm a little concerned that various places which check for isa<ConstantExpr> won't interact correctly with this: prior to this patch, the only possible constant expressions were ConstantData, GlobalValue, ConstantExpr, and ConstantAggregate. This introduces another possibility... and I'm not sure everything is prepared for it. Some stuff falls back gracefully, but I'm not sure everything does (e.g. Constant::containsConstantExpression).

Fixing that assumption is something we want to do eventually; so not really a problem, but consider spending a bit of time to try to find issues before we trip over them on user code.

@asl
Copy link
Collaborator

asl commented May 21, 2024

@efriedma-quic But it is Constant, not ConstantExpr. So in this sense it is not different from e.g. DSOLocalEquivalent or NoCFIValue that we are already having in tree. Or maybe I'm missing what do you mean?

@efriedma-quic
Copy link
Collaborator

@efriedma-quic But it is Constant, not ConstantExpr. So in this sense it is not different from e.g. DSOLocalEquivalent or NoCFIValue that we are already having in tree. Or maybe I'm missing what do you mean?

Well, no, but those are both used pretty infrequently. Actually, I forgot they already existed...

@asl
Copy link
Collaborator

asl commented May 21, 2024

Well, no, but those are both used pretty infrequently. Actually, I forgot they already existed...

Yeah. So I hope there is already some sufficient coverage these days for these custom constants. They might be used infrequently, yes, but still I'd expect corner cases were found.

In any case, ptrauth is an opt-in option from the frontend perspective. So, we'd certainly run bunch of tests when there will be end-to-end support to catch something missed / not-covered here. Just my 2c.

@asl
Copy link
Collaborator

asl commented May 23, 2024

@nikic Will you please take another look when you will have a chance? This PR is currently blocker for many frontend and backend-related pauth changes.

Thanks!

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@asl
Copy link
Collaborator

asl commented May 28, 2024

@ahmedbougacha This needs to be rebased and finally merged.

Tagging @Gerolf-Apple

This defines a new kind of IR Constant that represents a ptrauth signed
pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants used thus
far in the llvm ptrauth implementations, notably those used in the
Darwin and ELF ABIs being implemented for c/c++.  These signed pointer
constants are then lowered to ELF/MachO relocations.

These can be simply thought of as a constant `llvm.ptrauth.sign`, with
the interesting addition of discriminator computation: the `ptrauth`
constant can also represent a combined blend, when both address and
integer discriminator operands are used.  Both operands are otherwise
optional, with default values 0/null.

Co-Authored-by: Tim Northover <[email protected]>
@ahmedbougacha ahmedbougacha force-pushed the users/ahmedbougacha/ptrauth-constant branch from 817a4fc to c7779d0 Compare May 28, 2024 21:54
@ahmedbougacha ahmedbougacha merged commit 0edc97f into main May 28, 2024
6 of 7 checks passed
@ahmedbougacha ahmedbougacha deleted the users/ahmedbougacha/ptrauth-constant branch May 28, 2024 23:39
vg0204 pushed a commit to vg0204/llvm-project that referenced this pull request May 29, 2024
…inters. (llvm#85738)

This defines a new kind of IR Constant that represents a ptrauth signed
pointer, as used in AArch64 PAuth.

It allows representing most kinds of signed pointer constants used thus
far in the llvm ptrauth implementations, notably those used in the
Darwin and ELF ABIs being implemented for c/c++.  These signed pointer
constants are then lowered to ELF/MachO relocations.

These can be simply thought of as a constant `llvm.ptrauth.sign`, with
the interesting addition of discriminator computation: the `ptrauth`
constant can also represent a combined blend, when both address and
integer discriminator operands are used.  Both operands are otherwise
optional, with default values 0/null.
DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Constant);

/// The pointer that is signed in this ptrauth signed pointer.
Constant *getPointer() const { return cast<Constant>(Op<0>().get()); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahmedbougacha Is it a desired behavior that we return a non-const-qualified pointer to associated data when calling getters allowed for calls on const-qualified object? Shouldn't we have two getter overloads, one with const for implicit this parameter and one without that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have generally treated a lot of these sorts of types as value types, and there is little you can do with a mutable Constant that you can't with a const pointer. What little you can do is deeply suspicious and not common.
It's certainly reasonable to ask to add const variants here and elsewhere, but that comes at the cost of verbosity for arguable utility given the above – either way I'm very specifically trying not to litigate that here or in these PRs ;) This is just matching the rest of the file, see e.g. BlockAddress above

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[PAC] Implement proper constant signing
7 participants