[LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout. #104820


Merged

Conversation

paulwalker-arm
Collaborator

@paulwalker-arm paulwalker-arm commented Aug 19, 2024

Fix incorrect use of AArch64ISD::UZP1/UUNPK{HI,LO} in:
AArch64TargetLowering::LowerDIV
AArch64TargetLowering::LowerINSERT_SUBVECTOR

The latter highlighted DAG combines that relied on broken behaviour, which this patch also fixes.
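To make the invariant concrete: AArch64ISD::UZP1 requires operands whose type matches its result, and the unpack nodes produce values in the widened type. Below is a minimal sketch of the corrected LowerDIV shape, assuming it lives inside AArch64ISelLowering.cpp where AArch64ISD and SelectionDAG are in scope; lowerDivLikeOp is a hypothetical name, not a function from this patch, and it shows only the unsigned case (the signed path uses SUNPKLO/SUNPKHI):

  // Hypothetical helper mirroring the fixed LowerDIV flow: the halves are
  // computed in the widened type, then retyped with AArch64ISD::NVCAST so
  // that UZP1's operands match its result type, which the re-enabled
  // verifier now asserts.
  static SDValue lowerDivLikeOp(SDValue Op, SelectionDAG &DAG) {
    SDLoc DL(Op);
    EVT VT = Op.getValueType(); // e.g. nxv4i32
    EVT WidenedVT = VT.widenIntegerVectorElementType(*DAG.getContext())
                        .getHalfNumVectorElementsVT(*DAG.getContext()); // nxv2i64
    SDValue Op0Lo = DAG.getNode(AArch64ISD::UUNPKLO, DL, WidenedVT, Op.getOperand(0));
    SDValue Op0Hi = DAG.getNode(AArch64ISD::UUNPKHI, DL, WidenedVT, Op.getOperand(0));
    SDValue Op1Lo = DAG.getNode(AArch64ISD::UUNPKLO, DL, WidenedVT, Op.getOperand(1));
    SDValue Op1Hi = DAG.getNode(AArch64ISD::UUNPKHI, DL, WidenedVT, Op.getOperand(1));
    SDValue Lo = DAG.getNode(Op.getOpcode(), DL, WidenedVT, Op0Lo, Op1Lo);
    SDValue Hi = DAG.getNode(Op.getOpcode(), DL, WidenedVT, Op0Hi, Op1Hi);
    // Previously UZP1 was built directly from the WidenedVT values, giving
    // it operands with half the lane count its result type implies.
    Lo = DAG.getNode(AArch64ISD::NVCAST, DL, VT, Lo);
    Hi = DAG.getNode(AArch64ISD::NVCAST, DL, VT, Hi);
    return DAG.getNode(AArch64ISD::UZP1, DL, VT, Lo, Hi);
  }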

@llvmbot llvmbot added backend:AArch64 and llvm:SelectionDAG labels Aug 19, 2024
@llvmbot
Member

llvmbot commented Aug 19, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

Fix incorrect use of AArch64ISD::UZP1/UUNPK{HI,LO} in:
AArch64TargetLowering::LowerDIV
AArch64TargetLowering::LowerINSERT_SUBVECTOR

The latter highlighted DAG combines that relied on broken behaviour, which this patch also fixes.


Patch is 105.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104820.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+99-43)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+56-107)
  • (modified) llvm/test/CodeGen/AArch64/sve-bitcast.ll (+243-779)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index ab12c3b0e728a8..8d06206755d4f8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6140,6 +6140,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   case ISD::BSWAP:
     assert(VT.isInteger() && VT == N1.getValueType() && "Invalid BSWAP!");
+    if (VT.getScalarSizeInBits() == 8)
+      return N1;
     assert((VT.getScalarSizeInBits() % 16 == 0) &&
            "BSWAP types must be a multiple of 16 bits!");
     if (OpOpcode == ISD::UNDEF)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 97fb2c5f552731..265a425125d02d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1496,7 +1496,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::AVGCEILU, VT, Custom);
 
       if (!Subtarget->isLittleEndian())
-        setOperationAction(ISD::BITCAST, VT, Expand);
+        setOperationAction(ISD::BITCAST, VT, Custom);
 
       if (Subtarget->hasSVE2() ||
           (Subtarget->hasSME() && Subtarget->isStreaming()))
@@ -1510,9 +1510,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
     }
 
-    // Legalize unpacked bitcasts to REINTERPRET_CAST.
-    for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
-                    MVT::nxv4bf16, MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
+    // Type legalize unpacked bitcasts.
+    for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32})
       setOperationAction(ISD::BITCAST, VT, Custom);
 
     for (auto VT :
@@ -1587,6 +1586,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
                     MVT::nxv4f32, MVT::nxv2f64}) {
+      setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
@@ -1658,20 +1658,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setCondCodeAction(ISD::SETUGT, VT, Expand);
       setCondCodeAction(ISD::SETUEQ, VT, Expand);
       setCondCodeAction(ISD::SETONE, VT, Expand);
-
-      if (!Subtarget->isLittleEndian())
-        setOperationAction(ISD::BITCAST, VT, Expand);
     }
 
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
+      setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
-
-      if (!Subtarget->isLittleEndian())
-        setOperationAction(ISD::BITCAST, VT, Expand);
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -4960,22 +4955,35 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
     return LowerFixedLengthBitcastToSVE(Op, DAG);
 
   if (OpVT.isScalableVector()) {
-    // Bitcasting between unpacked vector types of different element counts is
-    // not a NOP because the live elements are laid out differently.
-    //                01234567
-    // e.g. nxv2i32 = XX??XX??
-    //      nxv4f16 = X?X?X?X?
-    if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
-      return SDValue();
+    assert(isTypeLegal(OpVT) && "Unexpected result type!");
 
-    if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
+    // Handle type legalisation first.
+    if (!isTypeLegal(ArgVT)) {
       assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
              "Expected int->fp bitcast!");
+
+      // Bitcasting between unpacked vector types of different element counts is
+      // not a NOP because the live elements are laid out differently.
+      //                01234567
+      // e.g. nxv2i32 = XX??XX??
+      //      nxv4f16 = X?X?X?X?
+      if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
+        return SDValue();
+
       SDValue ExtResult =
           DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
                       Op.getOperand(0));
       return getSVESafeBitCast(OpVT, ExtResult, DAG);
     }
+
+    // Bitcasts between legal types with the same element count are legal.
+    if (OpVT.getVectorElementCount() == ArgVT.getVectorElementCount())
+      return Op;
+
+    // getSVESafeBitCast does not support casting between unpacked types.
+    if (!isPackedVectorType(OpVT, DAG))
+      return SDValue();
+
     return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
   }
 
@@ -14877,10 +14885,11 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
     // NOP cast operands to the largest legal vector of the same element count.
     if (VT.isFloatingPoint()) {
       Vec0 = getSVESafeBitCast(NarrowVT, Vec0, DAG);
-      Vec1 = getSVESafeBitCast(WideVT, Vec1, DAG);
+      Vec1 = getSVESafeBitCast(NarrowVT, Vec1, DAG);
     } else {
       // Legal integer vectors are already their largest so Vec0 is fine as is.
       Vec1 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
+      Vec1 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, Vec1);
     }
 
     // To replace the top/bottom half of vector V with vector SubV we widen the
@@ -14889,11 +14898,13 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
     SDValue Narrow;
     if (Idx == 0) {
       SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
+      HiVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, HiVec0);
       Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);
     } else {
       assert(Idx == InVT.getVectorMinNumElements() &&
              "Invalid subvector index!");
       SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
+      LoVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, LoVec0);
       Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, LoVec0, Vec1);
     }
 
@@ -14993,7 +15004,9 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
   SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
   SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
   SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
-  return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
+  SDValue ResultLoCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultLo);
+  SDValue ResultHiCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultHi);
+  return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLoCast, ResultHiCast);
 }
 
 bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
@@ -19811,7 +19824,6 @@ static SDValue performConcatVectorsCombine(SDNode *N,
     // This optimization reduces instruction count.
     if (N00Opc == AArch64ISD::VLSHR && N10Opc == AArch64ISD::VLSHR &&
         N00->getOperand(1) == N10->getOperand(1)) {
-
       SDValue N000 = N00->getOperand(0);
       SDValue N100 = N10->getOperand(0);
       uint64_t N001ConstVal = N00->getConstantOperandVal(1),
@@ -19819,7 +19831,8 @@ static SDValue performConcatVectorsCombine(SDNode *N,
                NScalarSize = N->getValueType(0).getScalarSizeInBits();
 
       if (N001ConstVal == N101ConstVal && N001ConstVal > NScalarSize) {
-
+        N000 = DAG.getNode(AArch64ISD::NVCAST, dl, VT, N000);
+        N100 = DAG.getNode(AArch64ISD::NVCAST, dl, VT, N100);
         SDValue Uzp = DAG.getNode(AArch64ISD::UZP2, dl, VT, N000, N100);
         SDValue NewShiftConstant =
             DAG.getConstant(N001ConstVal - NScalarSize, dl, MVT::i32);
@@ -22626,7 +22639,7 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   SDValue Rshrnb = DAG.getNode(
       AArch64ISD::RSHRNB_I, DL, ResVT,
       {RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
-  return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
+  return DAG.getNode(AArch64ISD::NVCAST, DL, VT, Rshrnb);
 }
 
 static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
@@ -22689,25 +22702,41 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
     return Urshr;
 
-  if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
-    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
+  if (Op0.getOpcode() == AArch64ISD::NVCAST &&
+      ResVT.getVectorElementCount() == Op0.getOperand(0).getValueType().getVectorElementCount()*2){
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0->getOperand(0), DAG, Subtarget)) {
+      Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
+    }
+  }
 
-  if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
-    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
+  if (Op1.getOpcode() == AArch64ISD::NVCAST &&
+      ResVT.getVectorElementCount() == Op1.getOperand(0).getValueType().getVectorElementCount()*2){
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1->getOperand(0), DAG, Subtarget)) {
+      Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
+    }
+  }
 
-  // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
-  if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
-    if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
-      SDValue X = Op0.getOperand(0).getOperand(0);
-      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
+  // uzp1<ty>(nvcast(unpklo(uzp1<ty>(x, y))), z) => uzp1<ty>(x, z)
+  if (Op0.getOpcode() == AArch64ISD::NVCAST) {
+    if (Op0.getOperand(0).getOpcode() == AArch64ISD::UUNPKLO) {
+      if (Op0.getOperand(0).getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+        SDValue X = Op0.getOperand(0).getOperand(0).getOperand(0);
+        if (ResVT == X.getValueType())
+          return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
+      }
     }
   }
 
-  // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
-  if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
-    if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
-      SDValue Z = Op1.getOperand(0).getOperand(1);
-      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
+  // uzp1<ty>(x, nvcast(unpkhi(uzp1<ty>(y, z)))) => uzp1<ty>(x, z)
+  if (Op1.getOpcode() == AArch64ISD::NVCAST) {
+    if (Op1.getOperand(0).getOpcode() == AArch64ISD::UUNPKHI) {
+      if (Op1.getOperand(0).getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+        SDValue Z = Op1.getOperand(0).getOperand(0).getOperand(1);
+        if (ResVT == Z.getValueType())
+          return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
+      }
     }
   }
 
@@ -28877,7 +28906,20 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
   if (InVT != PackedInVT)
     Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
 
-  Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+  if (Subtarget->isLittleEndian() ||
+      PackedVT.getScalarSizeInBits() == PackedInVT.getScalarSizeInBits())
+    Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+  else {
+    EVT PackedVTAsInt = PackedVT.changeTypeToInteger();
+    EVT PackedInVTAsInt = PackedInVT.changeTypeToInteger();
+
+    // Simulate the effect of casting through memory.
+    Op = DAG.getNode(ISD::BITCAST, DL, PackedInVTAsInt, Op);
+    Op = DAG.getNode(ISD::BSWAP, DL, PackedInVTAsInt, Op);
+    Op = DAG.getNode(AArch64ISD::NVCAST, DL, PackedVTAsInt, Op);
+    Op = DAG.getNode(ISD::BSWAP, DL, PackedVTAsInt, Op);
+    Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+  }
 
   // Unpack result if required.
   if (VT != PackedVT)
@@ -29263,9 +29305,8 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
            VT.isInteger() && "Expected integer vectors!");
     assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
            "Expected vectors of equal size!");
-    // TODO: Enable assert once bogus creations have been fixed.
-    // assert(OpVT.getVectorElementCount() == VT.getVectorElementCount()*2 &&
-    //       "Expected result vector with half the lanes of its input!");
+    assert(OpVT.getVectorElementCount() == VT.getVectorElementCount()*2 &&
+           "Expected result vector with half the lanes of its input!");
     break;
   }
   case AArch64ISD::TRN1:
@@ -29281,8 +29322,23 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
     EVT Op1VT = N->getOperand(1).getValueType();
     assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
            "Expected vectors!");
-    // TODO: Enable assert once bogus creations have been fixed.
-    // assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
+    assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
+    break;
+  }
+  case AArch64ISD::RSHRNB_I: {
+    assert(N->getNumValues() == 1 && "Expected one result!");
+    assert(N->getNumOperands() == 2 && "Expected two operand!");
+    EVT VT = N->getValueType(0);
+    EVT Op0VT = N->getOperand(0).getValueType();
+    EVT Op1VT = N->getOperand(1).getValueType();
+    assert(VT.isVector() && Op0VT.isVector() && VT.isInteger() &&
+           Op0VT.isInteger() && "Expected integer vectors!");
+    assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
+           "Expected vectors of equal size!");
+    assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
+           "Expected result vector with half the lanes of its input!");
+    assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
+           "Expected second operand to be a constant i32!");
     break;
   }
   }
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d9a70b5ef02fcb..35035aae05ecb6 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2650,113 +2650,62 @@ let Predicates = [HasSVEorSME] in {
                                sub_32)>;
   }
 
-  // FIXME: BigEndian requires an additional REV instruction to satisfy the
-  // constraint that none of the bits change when stored to memory as one
-  // type, and reloaded as another type.
-  let Predicates = [IsLE] in {
-    def : Pat<(nxv16i8 (bitconvert nxv8i16:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv4i32:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv2i64:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv8f16:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv4f32:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv2f64:$src)), (nxv16i8 ZPR:$src)>;
-
-    def : Pat<(nxv8i16 (bitconvert nxv16i8:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv4i32:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv2i64:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv8f16:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv4f32:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv2f64:$src)), (nxv8i16 ZPR:$src)>;
-
-    def : Pat<(nxv4i32 (bitconvert nxv16i8:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv8i16:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv2i64:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv8f16:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv4f32:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv2f64:$src)), (nxv4i32 ZPR:$src)>;
-
-    def : Pat<(nxv2i64 (bitconvert nxv16i8:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv8i16:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv4i32:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv8f16:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv4f32:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv2f64:$src)), (nxv2i64 ZPR:$src)>;
-
-    def : Pat<(nxv8f16 (bitconvert nxv16i8:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv8i16:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv4i32:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv2i64:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv4f32:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv2f64:$src)), (nxv8f16 ZPR:$src)>;
-
-    def : Pat<(nxv4f32 (bitconvert nxv16i8:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv8i16:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv4i32:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv2i64:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv8f16:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv2f64:$src)), (nxv4f32 ZPR:$src)>;
-
-    def : Pat<(nxv2f64 (bitconvert nxv16i8:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv8i16:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv4i32:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv2i64:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv8f16:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv4f32:$src)), (nxv2f64 ZPR:$src)>;
-
-    def : Pat<(nxv8bf16 (bitconvert nxv16i8:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv8i16:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv4i32:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv2i64:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv8f16:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv4f32:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv2f64:$src)), (nxv8bf16 ZPR:$src)>;
-
-    def : Pat<(nxv16i8 (bitconvert nxv8bf16:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv8bf16:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv8bf16:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv8bf16:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv8bf16:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv8bf16:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv8bf16:$src)), (nxv2f64 ZPR:$src)>;
-
-    def : Pat<(nxv16i1 (bitconvert aarch64svcount:$src)), (nxv16i1 PPR:$src)>;
-    def : Pat<(aarch64svcount (bitconvert nxv16i1:$src)), (aarch64svcount PNR:$src)>;
-  }
-
-  // These allow casting from/to unpacked predicate types.
-  def : Pat<(nxv16i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast  nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast  nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast  nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast  nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast  nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast  nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv2i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv2i1 (reinterpret_cast  nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv2i1 (reinterpret_cast  nxv...
[truncated]


github-actions bot commented Aug 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@paulwalker-arm paulwalker-arm force-pushed the sve-invalid-node-construction branch from 245d3c6 to 747fde0 on August 20, 2024 at 13:07
[LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout.

Fix incorrect use of AArch64ISD::UZP1/UUNPK{HI,LO} in:
  AArch64TargetLowering::LowerDIV
  AArch64TargetLowering::LowerINSERT_SUBVECTOR

The latter highlighted DAG combines that relied on broken behaviour,
which this patch also fixes.
@paulwalker-arm paulwalker-arm force-pushed the sve-invalid-node-construction branch from 747fde0 to 8ed29b7 on August 30, 2024 at 12:57
@paulwalker-arm paulwalker-arm changed the title from "[Work In Progress][LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout." to "[LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout." on Aug 30, 2024
Comment on lines +22682 to +22685
SDValue Op = V.getOperand(0);
if (V.getValueType().getVectorElementCount() !=
Op.getValueType().getVectorElementCount() * 2)
return SDValue();
Contributor

Pedantic question, if possible here: what happens if one type is a double tuple and the other is a quad tuple? Does getVectorElementCount() return the element count for each of the scalable vectors in the quad, or just one scalable type within the tuple?

Collaborator Author

@paulwalker-arm paulwalker-arm Aug 30, 2024

getValueType() returns an EVT, which has no concept of tuple types. The implementation of getVectorElementCount will itself assert that it is a plain vector type.
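A tiny standalone illustration of that behaviour (my own snippet, not part of the patch; it assumes LLVM's public headers):

  #include "llvm/CodeGen/ValueTypes.h"
  #include "llvm/Support/TypeSize.h"
  #include <cassert>

  int main() {
    // EVT models a single (possibly scalable) vector type; tuple types have
    // no EVT representation, so element counts are always per plain vector.
    llvm::EVT VT = llvm::MVT::nxv4i32;
    assert(VT.getVectorElementCount() == llvm::ElementCount::getScalable(4));
    // The "half the lanes" relationship checked in the combine above:
    assert(llvm::EVT(llvm::MVT::nxv2i64).getVectorElementCount() * 2 ==
           VT.getVectorElementCount());
    // Calling getVectorElementCount() on a non-vector EVT would assert.
    return 0;
  }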

assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
break;
}
case AArch64ISD::RSHRNB_I: {
assert(N->getNumValues() == 1 && "Expected one result!");
assert(N->getNumOperands() == 2 && "Expected two operand!");
Contributor

Suggested change
assert(N->getNumOperands() == 2 && "Expected two operand!");
assert(N->getNumOperands() == 2 && "Expected two operands!");

EVT Op0VT = N->getOperand(0).getValueType();
EVT Op1VT = N->getOperand(1).getValueType();
assert(VT.isVector() && Op0VT.isVector() && VT.isInteger() &&
Op0VT.isInteger() && "Expected integer vectors!");
Contributor

Since the message for the second operand is very specific, this message could be more specific about which operands it refers to.
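One illustrative wording (my example, not a suggestion posted on the review):

  assert(VT.isVector() && Op0VT.isVector() && VT.isInteger() &&
         Op0VT.isInteger() &&
         "Expected result and operand #0 to be integer vectors!");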

@paulwalker-arm paulwalker-arm merged commit 2fef449 into llvm:main Sep 4, 2024
8 checks passed
@paulwalker-arm paulwalker-arm deleted the sve-invalid-node-construction branch September 20, 2024 13:43