Skip to content

[SelectionDAG] Add an ISD node for vector.extract.last.active #118810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

huntergr-arm
Copy link
Collaborator

@huntergr-arm huntergr-arm commented Dec 5, 2024

Based on feedback from the clastb codegen PR, I'm refactoring basic codegen for the vector.extract.last.active intrinsic to lower to an ISD node in SelectionDAGBuilder then expand in LegalizeVectorOps, instead of doing everything in the builder.

@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Dec 5, 2024
@huntergr-arm huntergr-arm marked this pull request as draft December 5, 2024 14:20
@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2024

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: Graham Hunter (huntergr-arm)

Changes

Based on feedback from the clastb codegen PR, I'm refactoring basic codegen for the vector.extract.last.active intrinsic to lower to an ISD node in SelectionDAGBuilder then expand in LegalizeVectorOps, instead of doing everything in the builder.

This doesn't quite work yet, but I'm sharing it now to help with similar refactoring for the partial reduction intrinsic codegen.


Full diff: https://github.com/llvm/llvm-project/pull/118810.diff

8 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+50)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp (+10)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+61)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+42-31)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+3)
  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+3)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 0b6d155b6d161e..f0a8ccc41eb3d3 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1480,6 +1480,10 @@ enum NodeType {
   // Output: Output Chain
   EXPERIMENTAL_VECTOR_HISTOGRAM,
 
+  // experimental.vector.extract.last.active intrinsic
+  // Operands: Data, Mask, PassThru
+  VECTOR_EXTRACT_LAST_ACTIVE,
+
   // llvm.clear_cache intrinsic
   // Operands: Input Chain, Start Addres, End Address
   // Outputs: Output Chain
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 986d69e6c7a9e0..73c78c556158ba 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -155,6 +155,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::ZERO_EXTEND_VECTOR_INREG:
                          Res = PromoteIntRes_EXTEND_VECTOR_INREG(N); break;
 
+  case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
+    Res = PromoteIntRes_VECTOR_EXTRACT_LAST_ACTIVE(N);
+    break;
+
   case ISD::SIGN_EXTEND:
   case ISD::VP_SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
@@ -2069,6 +2073,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     Res = PromoteIntOp_VECTOR_HISTOGRAM(N, OpNo);
     break;
+  case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
+    Res = PromoteIntOp_VECTOR_EXTRACT_LAST_ACTIVE(N, OpNo);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -2803,6 +2810,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N,
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
+SDValue
+DAGTypeLegalizer::PromoteIntOp_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N,
+                                                          unsigned OpNo) {
+  SmallVector<SDValue, 3> NewOps(N->ops());
+  NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+}
+
 //===----------------------------------------------------------------------===//
 //  Integer Result Expansion
 //===----------------------------------------------------------------------===//
@@ -2840,6 +2855,9 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::BUILD_PAIR:         ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
   case ISD::EXTRACT_ELEMENT:    ExpandRes_EXTRACT_ELEMENT(N, Lo, Hi); break;
   case ISD::EXTRACT_VECTOR_ELT: ExpandRes_EXTRACT_VECTOR_ELT(N, Lo, Hi); break;
+  case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
+    ExpandRes_VECTOR_EXTRACT_LAST_ACTIVE(N, Lo, Hi);
+    break;
   case ISD::VAARG:              ExpandRes_VAARG(N, Lo, Hi); break;
 
   case ISD::ANY_EXTEND:  ExpandIntRes_ANY_EXTEND(N, Lo, Hi); break;
@@ -6102,6 +6120,38 @@ SDValue DAGTypeLegalizer::PromoteIntRes_EXTEND_VECTOR_INREG(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+
+  SDLoc dl(N);
+
+  //  SDValue Data = N->getOperand(0);
+  //  SDValue Mask = N->getOperand(1);
+  //  SDValue PassThru = N->getOperand(2);
+  //
+  return DAG.getNode(ISD::VECTOR_EXTRACT_LAST_ACTIVE, dl, NVT, N->ops());
+
+  //
+  //  // If the input also needs to be promoted, do that first so we can get a
+  //  // get a good idea for the output type.
+  //  if (TLI.getTypeAction(*DAG.getContext(), Op0.getValueType())
+  //      == TargetLowering::TypePromoteInteger) {
+  //    SDValue In = GetPromotedInteger(Op0);
+  //
+  //    // If the new type is larger than NVT, use it. We probably won't need to
+  //    // promote it again.
+  //    EVT SVT = In.getValueType().getScalarType();
+  //    if (SVT.bitsGE(NVT)) {
+  //      SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SVT, In, Op1);
+  //      return DAG.getAnyExtOrTrunc(Ext, dl, NVT);
+  //    }
+  //  }
+  //
+  //  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, NVT, Op0, Op1);
+  //
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
   EVT OutVT = N->getValueType(0);
   EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 1703149aca7463..ef3cd66df25363 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -378,6 +378,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_VPFunnelShift(SDNode *N);
   SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
   SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
+  SDValue PromoteIntRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N);
 
   // Integer Operand Promotion.
   bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -428,6 +429,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntOp_VP_STRIDED(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
+  SDValue PromoteIntOp_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N, unsigned OpNo);
 
   void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
   void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -1214,6 +1216,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void ExpandRes_BUILD_PAIR        (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandRes_EXTRACT_ELEMENT   (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandRes_EXTRACT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N, SDValue &Lo,
+                                            SDValue &Hi);
   void ExpandRes_NormalLoad        (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandRes_VAARG             (SDNode *N, SDValue &Lo, SDValue &Hi);
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp
index 2655e8428309da..cbd04bd3d67e6e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp
@@ -244,6 +244,16 @@ void DAGTypeLegalizer::ExpandRes_EXTRACT_VECTOR_ELT(SDNode *N, SDValue &Lo,
     std::swap(Lo, Hi);
 }
 
+void DAGTypeLegalizer::ExpandRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N,
+                                                            SDValue &Lo,
+                                                            SDValue &Hi) {
+// FIXME: We need to do this by casting to smaller elements, deinterleaving,
+//        then performing 2 extract_last_active operations and returning the
+//        two parts.
+
+  llvm_unreachable("Implement extract_last_active expand result!");
+}
+
 void DAGTypeLegalizer::ExpandRes_NormalLoad(SDNode *N, SDValue &Lo,
                                             SDValue &Hi) {
   assert(ISD::isNormalLoad(N) && "This routine only for normal loads!");
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index db21e708970648..8e213f64536134 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -29,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
@@ -138,6 +139,7 @@ class VectorLegalizer {
   SDValue ExpandVP_FNEG(SDNode *Node);
   SDValue ExpandVP_FABS(SDNode *Node);
   SDValue ExpandVP_FCOPYSIGN(SDNode *Node);
+  SDValue ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node);
   SDValue ExpandSELECT(SDNode *Node);
   std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
   SDValue ExpandStore(SDNode *N);
@@ -465,6 +467,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECTOR_COMPRESS:
   case ISD::SCMP:
   case ISD::UCMP:
+  case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
@@ -1202,6 +1205,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VECTOR_COMPRESS:
     Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));
     return;
+  case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
+    Results.push_back(ExpandVECTOR_EXTRACT_LAST_ACTIVE(Node));
+    return;
   case ISD::SCMP:
   case ISD::UCMP:
     Results.push_back(TLI.expandCMP(Node, DAG));
@@ -1713,6 +1719,61 @@ SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) {
   return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
 }
 
+SDValue VectorLegalizer::ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node) {
+  dbgs() << "Expanding extract_last_active!!\n";
+  SDLoc DL(Node);
+  SDValue Data = Node->getOperand(0);
+  SDValue Mask = Node->getOperand(1);
+  SDValue PassThru = Node->getOperand(2);
+
+  EVT DataVT = Data.getValueType();
+  EVT ScalarVT = PassThru.getValueType();
+  EVT BoolVT = Mask.getValueType().getScalarType();
+
+  // Find a suitable type for a stepvector.
+  ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
+  if (DataVT.isScalableVector())
+    VScaleRange = getVScaleRange(&DAG.getMachineFunction().getFunction(), 64);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  unsigned EltWidth = TLI.getBitWidthForCttzElements(
+      ScalarVT.getTypeForEVT(*DAG.getContext()), DataVT.getVectorElementCount(),
+      /*ZeroIsPoison=*/true, &VScaleRange);
+  EVT StepVT = MVT::getIntegerVT(EltWidth);
+  EVT StepVecVT = DataVT.changeVectorElementType(StepVT);
+
+  // Promote to a legal type if necessary.
+  if (TLI.getTypeAction(StepVecVT.getSimpleVT()) ==
+      TargetLowering::TypePromoteInteger) {
+    StepVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), StepVecVT);
+    StepVT = StepVecVT.getVectorElementType();
+  }
+
+  // Zero out lanes with inactive elements, then find the highest remaining
+  // value from the stepvector.
+  SDValue Zeroes = DAG.getConstant(0, DL, StepVecVT);
+  SDValue StepVec = DAG.getStepVector(DL, StepVecVT);
+  SDValue ActiveElts = DAG.getSelect(DL, StepVecVT, Mask, StepVec, Zeroes);
+  // Unfortunately, VectorLegalizer does not recursively legalize all added
+  // nodes, just the end result nodes. LegalizeDAG doesn't handle VSELECT at
+  // all presently. So if we need to legalize a vselect then we have to do
+  // it here.
+  if (!TLI.isTypeLegal(StepVecVT) ||
+      TLI.getOperationAction(ISD::VSELECT, StepVecVT) == TargetLowering::Expand)
+    ActiveElts = LegalizeOp(ActiveElts);
+
+  SDValue HighestIdx = DAG.getNode(ISD::VECREDUCE_UMAX, DL, StepVT, ActiveElts);
+
+  // Extract the corresponding lane from the data vector
+  EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
+  SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, DL, ExtVT);
+  SDValue Extract =
+      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Data, Idx);
+
+  // If all mask lanes were inactive, choose the passthru value instead.
+  SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, DL, BoolVT, Mask);
+  return DAG.getSelect(DL, ScalarVT, AnyActive, Extract, PassThru);
+}
+
 void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   // Attempt to expand using TargetLowering.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index b72c5eff22f183..1f3e787df9ab03 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6431,44 +6431,55 @@ void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
                                                        unsigned Intrinsic) {
   assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
          "Tried lowering invalid vector extract last");
+
+  //           VECTOR_EXTRACT_LAST_ACTIVE,
+
   SDLoc sdl = getCurSDLoc();
   SDValue Data = getValue(I.getOperand(0));
   SDValue Mask = getValue(I.getOperand(1));
   SDValue PassThru = getValue(I.getOperand(2));
 
-  EVT DataVT = Data.getValueType();
-  EVT ScalarVT = PassThru.getValueType();
-  EVT BoolVT = Mask.getValueType().getScalarType();
-
-  // Find a suitable type for a stepvector.
-  ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
-  if (DataVT.isScalableVector())
-    VScaleRange = getVScaleRange(I.getCaller(), 64);
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  unsigned EltWidth = TLI.getBitWidthForCttzElements(
-      I.getType(), DataVT.getVectorElementCount(), /*ZeroIsPoison=*/true,
-      &VScaleRange);
-  MVT StepVT = MVT::getIntegerVT(EltWidth);
-  EVT StepVecVT = DataVT.changeVectorElementType(StepVT);
-
-  // Zero out lanes with inactive elements, then find the highest remaining
-  // value from the stepvector.
-  SDValue Zeroes = DAG.getConstant(0, sdl, StepVecVT);
-  SDValue StepVec = DAG.getStepVector(sdl, StepVecVT);
-  SDValue ActiveElts = DAG.getSelect(sdl, StepVecVT, Mask, StepVec, Zeroes);
-  SDValue HighestIdx =
-      DAG.getNode(ISD::VECREDUCE_UMAX, sdl, StepVT, ActiveElts);
-
-  // Extract the corresponding lane from the data vector
-  EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
-  SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, sdl, ExtVT);
-  SDValue Extract =
-      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ScalarVT, Data, Idx);
-
-  // If all mask lanes were inactive, choose the passthru value instead.
-  SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
-  SDValue Result = DAG.getSelect(sdl, ScalarVT, AnyActive, Extract, PassThru);
+  EVT ResultVT = TLI.getValueType(DAG.getDataLayout(), I.getType());
+
+  SDValue Result = DAG.getNode(ISD::VECTOR_EXTRACT_LAST_ACTIVE, sdl, ResultVT,
+                               Data, Mask, PassThru);
+
   setValue(&I, Result);
+
+  //  EVT DataVT = Data.getValueType();
+  //  EVT ScalarVT = PassThru.getValueType();
+  //  EVT BoolVT = Mask.getValueType().getScalarType();
+  //
+  //  // Find a suitable type for a stepvector.
+  //  ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
+  //  if (DataVT.isScalableVector())
+  //    VScaleRange = getVScaleRange(I.getCaller(), 64);
+  //  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  //  unsigned EltWidth = TLI.getBitWidthForCttzElements(
+  //      I.getType(), DataVT.getVectorElementCount(), /*ZeroIsPoison=*/true,
+  //      &VScaleRange);
+  //  MVT StepVT = MVT::getIntegerVT(EltWidth);
+  //  EVT StepVecVT = DataVT.changeVectorElementType(StepVT);
+  //
+  //  // Zero out lanes with inactive elements, then find the highest remaining
+  //  // value from the stepvector.
+  //  SDValue Zeroes = DAG.getConstant(0, sdl, StepVecVT);
+  //  SDValue StepVec = DAG.getStepVector(sdl, StepVecVT);
+  //  SDValue ActiveElts = DAG.getSelect(sdl, StepVecVT, Mask, StepVec, Zeroes);
+  //  SDValue HighestIdx =
+  //      DAG.getNode(ISD::VECREDUCE_UMAX, sdl, StepVT, ActiveElts);
+  //
+  //  // Extract the corresponding lane from the data vector
+  //  EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
+  //  SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, sdl, ExtVT);
+  //  SDValue Extract =
+  //      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ScalarVT, Data, Idx);
+  //
+  //  // If all mask lanes were inactive, choose the passthru value instead.
+  //  SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
+  //  SDValue Result = DAG.getSelect(sdl, ScalarVT, AnyActive, Extract,
+  //  PassThru); setValue(&I, Result);
 }
 
 /// Lower the call to the specified intrinsic function.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 580ff19065557b..42cbb721703d99 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return "histogram";
 
+  case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
+    return "extract_last_active";
+
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
   case ISD::SDID:                                                              \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 392cfbdd21273d..5ea39124a8e55a 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -818,6 +818,9 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SDOPC, VT, Expand);
 #include "llvm/IR/VPIntrinsics.def"
 
+    // Masked vector extracts default to expand.
+    setOperationAction(ISD::VECTOR_EXTRACT_LAST_ACTIVE, VT, Expand);
+
     // FP environment operations default to expand.
     setOperationAction(ISD::GET_FPENV, VT, Expand);
     setOperationAction(ISD::SET_FPENV, VT, Expand);

Copy link

github-actions bot commented Dec 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@huntergr-arm huntergr-arm force-pushed the expand-legalization-extract-last-active branch from ebf15e9 to c754687 Compare December 6, 2024 15:59
@huntergr-arm huntergr-arm changed the title Not-quite-working prototype with ISD node [SelectionDAG] Add an ISD node for vector.extract.last.active Dec 6, 2024
@huntergr-arm
Copy link
Collaborator Author

I've changed the base method of lowering for the vector.extract.last.active to use a dedicated ISD node, then implemented expansion in LegalizeVectorOps. I ran into a few problems along the way regarding legalization of types and operations being added that weren't recursively expanded if required by the target, leading to selection failures. I've put in workarounds for now, but we'll want something better before this PR lands.

One of the problems is the result of calling TLI.getBitWidthForCttzElements(...). For AArch64, we sometimes end up with a VT that's a little too small for a vector, giving us an illegal type of v4i8 when trying to extract from a v4i32. RISC-V on the other hand seems to always use i64, giving us the illegal nxv16i64 when extracting from nxv16i8, which resulted in trying to expand vecreduce_umax for a scalable vector.

The AArch64 backend can probably avoid the problem by being a bit more strict about which operations are marked as Expand vs. Legal. Not sure what those working on RVV would like to do here.

#112738 will be updated once this work is complete.

Currently the changes result in worse codegen for both AArch64 and RISC-V targets, but implementing custom lowering will be easy with just the single node to match. I do wonder if there's a better representation, however. A couple of options come to mind:

  1. Split into an extract_last_active node without a passthru, and use a separate or reduction + scalar select sequence if the passthru from IR wasn't poison.
  2. Split into a find_last_active_index which only finds the index of the last active lane, then use the existing extract_vector_elt node (and the or reduction + select if needed, as with the first option).

I won't be able to continue this before next year, so no hurry on review.

Since we shouldn't be changing lowering in SelectionDAGBuilder based on
the target, introduce a new ISD node for extract.last.active and
perform the current lowering in LegalizeVectorOps.

This results in worse codegen for now, but it's easy for a target to
match a single ISD node and improve the output.
@huntergr-arm huntergr-arm force-pushed the expand-legalization-extract-last-active branch from c754687 to 456f079 Compare January 8, 2025 10:21
@huntergr-arm huntergr-arm marked this pull request as ready for review January 8, 2025 10:22
@huntergr-arm
Copy link
Collaborator Author

Back to working on this. Latest patch is just a rebase, but I'll start experimenting with the different approaches I mentioned last year to see if those feel less hacky. I think we'll still need to look into getBitWidthForCttzElements, as making VTs based on the returned value often causes problems (nodes that can't be legalized for targets with scalable vectors).

@huntergr-arm
Copy link
Collaborator Author

Splitting out just the passthru handling didn't change the generated code much, but splitting out the extract from finding the last active index seems to have removed the regressions (and possibly improved RV64 codegen? slli+srli -> andi). DAG combining to lastb/clastb is very easy with this approach, so I think this is close to what we want.

There's still a slight problem with getBitWidthForCttzElements though, at least for RV32 -- I have it hardcoded to cap at 32bits to avoid needing to expand the result into 2 32b scalar registers. I'll investigate whether I can remove the type promotion (needed for AArch64 atm) from the vector op expansion code in the meantime.

@huntergr-arm
Copy link
Collaborator Author

A bit more experimentation shows that the 32b cap is no longer needed with the operation split across multiple operations.

The type promotion is still required, so I've improved the comment to explain why. Without it, we end up with a problem if the chosen type is illegal. For example, if we expand to a set of nodes including a <i8> = vecreduce_umax <nxv8i8>, that would warrant promotion of the input type to <nxv8i16>. If I use setOperationAction to set that up, then we hit an assert in VectorLegalizer::Promote, since it currently doesn't handle the case of a new integer vector VT being larger than the current VT -- it only wants to promote to a smaller number of wider elements. I don't think this PR is the best place to address that, so I've just chosen legal types while expanding instead.

@SamTebbs33 SamTebbs33 self-requested a review January 14, 2025 13:40
Copy link
Collaborator

@SamTebbs33 SamTebbs33 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.

@huntergr-arm huntergr-arm merged commit d9f165d into llvm:main Jan 20, 2025
6 of 8 checks passed
@huntergr-arm huntergr-arm deleted the expand-legalization-extract-last-active branch January 22, 2025 16:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants