Skip to content

[VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI #131300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 24, 2025

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Mar 14, 2025

VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter two.

The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.

@llvmbot
Copy link
Member

llvmbot commented Mar 14, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Luke Lau (lukel97)

Changes

VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter, which allows it inherit from VPRecipeWithIRFlags.

This also allows us to remove createReduction in LoopUtils since it now only has one user in VPInstruction::ComputeReductionResult.

The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.


Full diff: https://github.com/llvm/llvm-project/pull/131300.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/LoopUtils.h (+5-12)
  • (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+6-29)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+26-15)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+39-37)
  • (modified) llvm/test/Transforms/LoopVectorize/vplan-printing.ll (+1-1)
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 8f4c0c88336ac..3ad7b8f17856c 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
                              RecurKind RdxKind);
 /// Overloaded function to generate vector-predication intrinsics for
 /// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
-                             const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
+                             FastMathFlags FMFs);
 
 /// Create a reduction of the given vector \p Src for a reduction of the
 /// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -427,20 +427,13 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
 Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
                                  const RecurrenceDescriptor &Desc);
 
-/// Create a generic reduction using a recurrence descriptor \p Desc
-/// Fast-math-flags are propagated using the RecurrenceDescriptor.
-Value *createReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
-                       Value *Src, PHINode *OrigPhi = nullptr);
-
 /// Create an ordered reduction intrinsic using the given recurrence
-/// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+/// kind \p Kind.
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind Kind, Value *Src,
                               Value *Start);
 /// Overloaded function to generate vector-predication intrinsics for ordered
 /// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind Kind, Value *Src,
                               Value *Start);
 
 /// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 84c08556f8a25..b20ce27f8cfb3 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,41 +1333,20 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
 }
 
 Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
-                                   const RecurrenceDescriptor &Desc) {
-  RecurKind Kind = Desc.getRecurrenceKind();
+                                   RecurKind Kind, FastMathFlags FMFs) {
   assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
          "AnyOf reduction is not supported.");
   Intrinsic::ID Id = getReductionIntrinsicID(Kind);
   auto *SrcTy = cast<VectorType>(Src->getType());
   Type *SrcEltTy = SrcTy->getElementType();
-  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
   Value *Ops[] = {Iden, Src};
   return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
 }
 
-Value *llvm::createReduction(IRBuilderBase &B,
-                             const RecurrenceDescriptor &Desc, Value *Src,
-                             PHINode *OrigPhi) {
-  // TODO: Support in-order reductions based on the recurrence descriptor.
-  // All ops in the reduction inherit fast-math-flags from the recurrence
-  // descriptor.
-  IRBuilderBase::FastMathFlagGuard FMFGuard(B);
-  B.setFastMathFlags(Desc.getFastMathFlags());
-
-  RecurKind RK = Desc.getRecurrenceKind();
-  if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
-    return createAnyOfReduction(B, Src, Desc, OrigPhi);
-  if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
-    return createFindLastIVReduction(B, Src, Desc);
-
-  return createSimpleReduction(B, Src, RK);
-}
-
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1375,11 +1354,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
   return B.CreateFAddReduce(Start, Src);
 }
 
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index f78eb84b0c445..890eff8d28b7f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2237,19 +2237,21 @@ class VPInterleaveRecipe : public VPRecipeBase {
 /// A recipe to represent inloop reduction operations, performing a reduction on
 /// a vector operand into a scalar value, and adding the result to a chain.
 /// The Operands are {ChainOp, VecOp, [Condition]}.
-class VPReductionRecipe : public VPSingleDefRecipe {
+class VPReductionRecipe : public VPRecipeWithIRFlags {
   /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
+  RecurKind RdxKind;
   bool IsOrdered;
   /// Whether the reduction is conditional.
   bool IsConditional = false;
 
 protected:
-  VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                    Instruction *I, ArrayRef<VPValue *> Operands,
-                    VPValue *CondOp, bool IsOrdered, DebugLoc DL)
-      : VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
+  VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
+                    FastMathFlags FMFs, Instruction *I,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, FMFs, DL), RdxKind(RdxKind),
         IsOrdered(IsOrdered) {
+    setUnderlyingValue(I);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2257,19 +2259,25 @@ class VPReductionRecipe : public VPSingleDefRecipe {
   }
 
 public:
-  VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+  VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
                     bool IsOrdered, DebugLoc DL = {})
-      : VPReductionRecipe(VPDef::VPReductionSC, R, I,
+      : VPReductionRecipe(VPRecipeBase::VPReductionSC, RdxKind, FMFs, I,
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
 
+  VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+                    VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL = {})
+      : VPReductionRecipe(R.getRecurrenceKind(), R.getFastMathFlags(), I,
+                          ChainOp, VecOp, CondOp, IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
-    return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
-                                 getVecOp(), getCondOp(), IsOrdered,
-                                 getDebugLoc());
+    return new VPReductionRecipe(RdxKind, getFastMathFlags(),
+                                 getUnderlyingInstr(), getChainOp(), getVecOp(),
+                                 getCondOp(), IsOrdered, getDebugLoc());
   }
 
   static inline bool classof(const VPRecipeBase *R) {
@@ -2295,9 +2303,11 @@ class VPReductionRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
+  /// Return the recurrence kind for the in-loop reduction.
+  RecurKind getRecurrenceKind() const { return RdxKind; }
+  /// Return the opcode for the recurrence for the in-loop reduction.
+  unsigned getOpcode() const {
+    return RecurrenceDescriptor::getOpcode(RdxKind);
   }
   /// Return true if the in-loop reduction is ordered.
   bool isOrdered() const { return IsOrdered; };
@@ -2321,7 +2331,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 public:
   VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
       : VPReductionRecipe(
-            VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
+            VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
+            R.getFastMathFlags(),
             cast_or_null<Instruction>(R.getUnderlyingValue()),
             ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
             R.isOrdered(), R.getDebugLoc()) {}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 6e396eda6aac6..3ea737f1a088b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -666,8 +666,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
          RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
          RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
         !PhiR->isInLoop()) {
-      ReducedPartRdx =
-          createReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
+      IRBuilderBase::FastMathFlagGuard FMFG(Builder);
+      Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+      if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+        ReducedPartRdx =
+            createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
+      else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
+        ReducedPartRdx =
+            createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
+      else
+        ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
+
       // If the reduction can be performed in a smaller type, we need to extend
       // the reduction to the wider type before we branch to the original loop.
       if (PhiTy != RdxDesc.getRecurrenceType())
@@ -2263,12 +2272,13 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPReductionRecipe::execute(VPTransformState &State) {
   assert(!State.Lane && "Reduction being replicated.");
   Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
-  RecurKind Kind = RdxDesc.getRecurrenceKind();
+  RecurKind Kind = getRecurrenceKind();
   assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
          "In-loop AnyOf reductions aren't currently supported");
+
   // Propagate the fast-math flags carried by the underlying instruction.
   IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
-  State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+  State.Builder.setFastMathFlags(getFastMathFlags());
   State.setDebugLocFrom(getDebugLoc());
   Value *NewVecOp = State.get(getVecOp());
   if (VPValue *Cond = getCondOp()) {
@@ -2276,8 +2286,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
     VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
     Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
 
-    Value *Start =
-        getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
+    Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
     if (State.VF.isVector())
       Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
 
@@ -2289,21 +2298,20 @@ void VPReductionRecipe::execute(VPTransformState &State) {
   if (IsOrdered) {
     if (State.VF.isVector())
       NewRed =
-          createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+          createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
     else
-      NewRed = State.Builder.CreateBinOp(
-          (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
+      NewRed = State.Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(),
+                                         PrevInChain, NewVecOp);
     PrevInChain = NewRed;
     NextInChain = NewRed;
   } else {
     PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
-    NewRed = createReduction(State.Builder, RdxDesc, NewVecOp);
+    NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
     if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
-      NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
-                                   NewRed, PrevInChain);
+      NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
     else
       NextInChain = State.Builder.CreateBinOp(
-          (Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
+          (Instruction::BinaryOps)getOpcode(), NewRed, PrevInChain);
   }
   State.set(this, NextInChain, /*IsScalar*/ true);
 }
@@ -2314,10 +2322,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
   // Propagate the fast-math flags carried by the underlying instruction.
   IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
-  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+  Builder.setFastMathFlags(getFastMathFlags());
 
-  RecurKind Kind = RdxDesc.getRecurrenceKind();
+  RecurKind Kind = getRecurrenceKind();
   Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
   Value *VecOp = State.get(getVecOp());
   Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2334,24 +2341,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
 
   Value *NewRed;
   if (isOrdered()) {
-    NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+    NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
   } else {
-    NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+    NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
     if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
       NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
     else
-      NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
-                                   NewRed, Prev);
+      NewRed = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), NewRed,
+                                   Prev);
   }
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
-  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  RecurKind RdxKind = getRecurrenceKind();
   Type *ElementTy = Ctx.Types.inferScalarType(this);
   auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
-  unsigned Opcode = RdxDesc.getOpcode();
 
   // TODO: Support any-of and in-loop reductions.
   assert(
@@ -2363,20 +2369,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "In-loop reduction not implemented in VPlan-based cost model currently.");
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
-
   // Cost = Reduction cost + BinOp cost
   InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
+      Ctx.TTI.getArithmeticInstrCost(getOpcode(), ElementTy, Ctx.CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
     return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+                      Id, VectorTy, getFastMathFlags(), Ctx.CostKind);
   }
 
   return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+                    getOpcode(), VectorTy, getFastMathFlags(), Ctx.CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2389,21 +2392,21 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " +";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << " reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (RdxDesc.IntermediateStore)
-    O << " (with final reduction value stored in invariant address sank "
-         "outside of loop)";
 }
 
 void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
                                  VPSlotTracker &SlotTracker) const {
-  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  RecurKind Kind = getRecurrenceKind();
   O << Indent << "REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2411,7 +2414,9 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " +";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << " vp.reduce."
+    << Instruction::getOpcodeName(RecurrenceDescriptor::getOpcode(Kind))
+    << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
   O << ", ";
   getEVL()->printAsOperand(O, SlotTracker);
@@ -2420,9 +2425,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (RdxDesc.IntermediateStore)
-    O << " (with final reduction value stored in invariant address sank "
-         "outside of loop)";
 }
 #endif
 
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 00d8de67a3b40..791db17b9a7d1 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
 ; CHECK-NEXT:   CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
 ; CHECK-NEXT:   vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
 ; CHECK-NEXT:   WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
-; CHECK-NEXT:   REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
+; CHECK-NEXT:   REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
 ; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
 ; CHECK-NEXT: No successors

@llvmbot
Copy link
Member

llvmbot commented Mar 14, 2025

@llvm/pr-subscribers-vectorizers

Author: Luke Lau (lukel97)

Changes

VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter, which allows it inherit from VPRecipeWithIRFlags.

This also allows us to remove createReduction in LoopUtils since it now only has one user in VPInstruction::ComputeReductionResult.

The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.


Full diff: https://github.com/llvm/llvm-project/pull/131300.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/LoopUtils.h (+5-12)
  • (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+6-29)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+26-15)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+39-37)
  • (modified) llvm/test/Transforms/LoopVectorize/vplan-printing.ll (+1-1)
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 8f4c0c88336ac..3ad7b8f17856c 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
                              RecurKind RdxKind);
 /// Overloaded function to generate vector-predication intrinsics for
 /// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
-                             const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
+                             FastMathFlags FMFs);
 
 /// Create a reduction of the given vector \p Src for a reduction of the
 /// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -427,20 +427,13 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
 Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
                                  const RecurrenceDescriptor &Desc);
 
-/// Create a generic reduction using a recurrence descriptor \p Desc
-/// Fast-math-flags are propagated using the RecurrenceDescriptor.
-Value *createReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
-                       Value *Src, PHINode *OrigPhi = nullptr);
-
 /// Create an ordered reduction intrinsic using the given recurrence
-/// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+/// kind \p Kind.
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind Kind, Value *Src,
                               Value *Start);
 /// Overloaded function to generate vector-predication intrinsics for ordered
 /// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind Kind, Value *Src,
                               Value *Start);
 
 /// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 84c08556f8a25..b20ce27f8cfb3 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,41 +1333,20 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
 }
 
 Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
-                                   const RecurrenceDescriptor &Desc) {
-  RecurKind Kind = Desc.getRecurrenceKind();
+                                   RecurKind Kind, FastMathFlags FMFs) {
   assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
          "AnyOf reduction is not supported.");
   Intrinsic::ID Id = getReductionIntrinsicID(Kind);
   auto *SrcTy = cast<VectorType>(Src->getType());
   Type *SrcEltTy = SrcTy->getElementType();
-  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
   Value *Ops[] = {Iden, Src};
   return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
 }
 
-Value *llvm::createReduction(IRBuilderBase &B,
-                             const RecurrenceDescriptor &Desc, Value *Src,
-                             PHINode *OrigPhi) {
-  // TODO: Support in-order reductions based on the recurrence descriptor.
-  // All ops in the reduction inherit fast-math-flags from the recurrence
-  // descriptor.
-  IRBuilderBase::FastMathFlagGuard FMFGuard(B);
-  B.setFastMathFlags(Desc.getFastMathFlags());
-
-  RecurKind RK = Desc.getRecurrenceKind();
-  if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
-    return createAnyOfReduction(B, Src, Desc, OrigPhi);
-  if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
-    return createFindLastIVReduction(B, Src, Desc);
-
-  return createSimpleReduction(B, Src, RK);
-}
-
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1375,11 +1354,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
   return B.CreateFAddReduce(Start, Src);
 }
 
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index f78eb84b0c445..890eff8d28b7f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2237,19 +2237,21 @@ class VPInterleaveRecipe : public VPRecipeBase {
 /// A recipe to represent inloop reduction operations, performing a reduction on
 /// a vector operand into a scalar value, and adding the result to a chain.
 /// The Operands are {ChainOp, VecOp, [Condition]}.
-class VPReductionRecipe : public VPSingleDefRecipe {
+class VPReductionRecipe : public VPRecipeWithIRFlags {
   /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
+  RecurKind RdxKind;
   bool IsOrdered;
   /// Whether the reduction is conditional.
   bool IsConditional = false;
 
 protected:
-  VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                    Instruction *I, ArrayRef<VPValue *> Operands,
-                    VPValue *CondOp, bool IsOrdered, DebugLoc DL)
-      : VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
+  VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
+                    FastMathFlags FMFs, Instruction *I,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, FMFs, DL), RdxKind(RdxKind),
         IsOrdered(IsOrdered) {
+    setUnderlyingValue(I);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2257,19 +2259,25 @@ class VPReductionRecipe : public VPSingleDefRecipe {
   }
 
 public:
-  VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+  VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
                     bool IsOrdered, DebugLoc DL = {})
-      : VPReductionRecipe(VPDef::VPReductionSC, R, I,
+      : VPReductionRecipe(VPRecipeBase::VPReductionSC, RdxKind, FMFs, I,
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
 
+  VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+                    VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL = {})
+      : VPReductionRecipe(R.getRecurrenceKind(), R.getFastMathFlags(), I,
+                          ChainOp, VecOp, CondOp, IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
-    return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
-                                 getVecOp(), getCondOp(), IsOrdered,
-                                 getDebugLoc());
+    return new VPReductionRecipe(RdxKind, getFastMathFlags(),
+                                 getUnderlyingInstr(), getChainOp(), getVecOp(),
+                                 getCondOp(), IsOrdered, getDebugLoc());
   }
 
   static inline bool classof(const VPRecipeBase *R) {
@@ -2295,9 +2303,11 @@ class VPReductionRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
+  /// Return the recurrence kind for the in-loop reduction.
+  RecurKind getRecurrenceKind() const { return RdxKind; }
+  /// Return the opcode for the recurrence for the in-loop reduction.
+  unsigned getOpcode() const {
+    return RecurrenceDescriptor::getOpcode(RdxKind);
   }
   /// Return true if the in-loop reduction is ordered.
   bool isOrdered() const { return IsOrdered; };
@@ -2321,7 +2331,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 public:
   VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
       : VPReductionRecipe(
-            VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
+            VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
+            R.getFastMathFlags(),
             cast_or_null<Instruction>(R.getUnderlyingValue()),
             ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
             R.isOrdered(), R.getDebugLoc()) {}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 6e396eda6aac6..3ea737f1a088b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -666,8 +666,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
          RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
          RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
         !PhiR->isInLoop()) {
-      ReducedPartRdx =
-          createReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
+      IRBuilderBase::FastMathFlagGuard FMFG(Builder);
+      Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+      if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+        ReducedPartRdx =
+            createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
+      else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
+        ReducedPartRdx =
+            createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
+      else
+        ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
+
       // If the reduction can be performed in a smaller type, we need to extend
       // the reduction to the wider type before we branch to the original loop.
       if (PhiTy != RdxDesc.getRecurrenceType())
@@ -2263,12 +2272,13 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPReductionRecipe::execute(VPTransformState &State) {
   assert(!State.Lane && "Reduction being replicated.");
   Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
-  RecurKind Kind = RdxDesc.getRecurrenceKind();
+  RecurKind Kind = getRecurrenceKind();
   assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
          "In-loop AnyOf reductions aren't currently supported");
+
   // Propagate the fast-math flags carried by the underlying instruction.
   IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
-  State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+  State.Builder.setFastMathFlags(getFastMathFlags());
   State.setDebugLocFrom(getDebugLoc());
   Value *NewVecOp = State.get(getVecOp());
   if (VPValue *Cond = getCondOp()) {
@@ -2276,8 +2286,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
     VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
     Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
 
-    Value *Start =
-        getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
+    Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
     if (State.VF.isVector())
       Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
 
@@ -2289,21 +2298,20 @@ void VPReductionRecipe::execute(VPTransformState &State) {
   if (IsOrdered) {
     if (State.VF.isVector())
       NewRed =
-          createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+          createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
     else
-      NewRed = State.Builder.CreateBinOp(
-          (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
+      NewRed = State.Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(),
+                                         PrevInChain, NewVecOp);
     PrevInChain = NewRed;
     NextInChain = NewRed;
   } else {
     PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
-    NewRed = createReduction(State.Builder, RdxDesc, NewVecOp);
+    NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
     if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
-      NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
-                                   NewRed, PrevInChain);
+      NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
     else
       NextInChain = State.Builder.CreateBinOp(
-          (Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
+          (Instruction::BinaryOps)getOpcode(), NewRed, PrevInChain);
   }
   State.set(this, NextInChain, /*IsScalar*/ true);
 }
@@ -2314,10 +2322,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
   // Propagate the fast-math flags carried by the underlying instruction.
   IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
-  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+  Builder.setFastMathFlags(getFastMathFlags());
 
-  RecurKind Kind = RdxDesc.getRecurrenceKind();
+  RecurKind Kind = getRecurrenceKind();
   Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
   Value *VecOp = State.get(getVecOp());
   Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2334,24 +2341,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
 
   Value *NewRed;
   if (isOrdered()) {
-    NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+    NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
   } else {
-    NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+    NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
     if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
       NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
     else
-      NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
-                                   NewRed, Prev);
+      NewRed = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), NewRed,
+                                   Prev);
   }
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
-  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  RecurKind RdxKind = getRecurrenceKind();
   Type *ElementTy = Ctx.Types.inferScalarType(this);
   auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
-  unsigned Opcode = RdxDesc.getOpcode();
 
   // TODO: Support any-of and in-loop reductions.
   assert(
@@ -2363,20 +2369,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "In-loop reduction not implemented in VPlan-based cost model currently.");
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
-
   // Cost = Reduction cost + BinOp cost
   InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
+      Ctx.TTI.getArithmeticInstrCost(getOpcode(), ElementTy, Ctx.CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
     return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+                      Id, VectorTy, getFastMathFlags(), Ctx.CostKind);
   }
 
   return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+                    getOpcode(), VectorTy, getFastMathFlags(), Ctx.CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2389,21 +2392,21 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " +";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << " reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (RdxDesc.IntermediateStore)
-    O << " (with final reduction value stored in invariant address sank "
-         "outside of loop)";
 }
 
 void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
                                  VPSlotTracker &SlotTracker) const {
-  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  RecurKind Kind = getRecurrenceKind();
   O << Indent << "REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2411,7 +2414,9 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " +";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << " vp.reduce."
+    << Instruction::getOpcodeName(RecurrenceDescriptor::getOpcode(Kind))
+    << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
   O << ", ";
   getEVL()->printAsOperand(O, SlotTracker);
@@ -2420,9 +2425,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (RdxDesc.IntermediateStore)
-    O << " (with final reduction value stored in invariant address sank "
-         "outside of loop)";
 }
 #endif
 
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 00d8de67a3b40..791db17b9a7d1 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
 ; CHECK-NEXT:   CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
 ; CHECK-NEXT:   vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
 ; CHECK-NEXT:   WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
-; CHECK-NEXT:   REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
+; CHECK-NEXT:   REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
 ; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
 ; CHECK-NEXT: No successors

@ElvisWang123
Copy link
Contributor

FYI, I have two similar patches of changing the parent of VPReductionRecipe to VPRecipeWithIRFlags.

My patches are focus on eliminating the dependency of the underlying instruction when printing the fastMathFlags and control the FMF by the VPRecipeWIthIRFlags, which is different to your patch that only construct the VPReductionRecipe by the RecurrenceKind.

@lukel97
Copy link
Contributor Author

lukel97 commented Mar 14, 2025

FYI, I have two similar patches of changing the parent of VPReductionRecipe to VPRecipeWithIRFlags.

My patches are focus on eliminating the dependency of the underlying instruction when printing the fastMathFlags and control the FMF by the VPRecipeWIthIRFlags, which is different to your patch that only construct the VPReductionRecipe by the RecurrenceKind.

Oh sorry I never noticed those patches! Yeah, adding the VPRecipeWIthIRFlags bit is the part that is similar here, I use it in this patch to store the FMF. I can try and stack this patch on top of #130881 then.

lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 14, 2025
This is split off from llvm#131300.

A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so when it calls createReduction it always calls createSimpleReduction.

If we replace the call then it leaves createReduction with one user in VPInstruction::ComputeReductionResult, which we can inline and then remove.
lukel97 added a commit that referenced this pull request Mar 17, 2025
This is split off from #131300.

A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so
when it calls createReduction it always calls createSimpleReduction.

If we replace the call then it leaves createReduction with one user in
VPInstruction::ComputeReductionResult, which we can inline and then
remove.
@lukel97 lukel97 force-pushed the loop-vectorize/VPReductionRecipe-RecurKind branch from aa6ada6 to 58d7419 Compare March 17, 2025 17:01
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 17, 2025
This is split off from llvm#131300.

A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so
when it calls createReduction it always calls createSimpleReduction.

If we replace the call then it leaves createReduction with one user in
VPInstruction::ComputeReductionResult, which we can inline and then
remove.
@lukel97 lukel97 force-pushed the loop-vectorize/VPReductionRecipe-RecurKind branch from 58d7419 to c14193a Compare March 17, 2025 17:02
lukel97 added 2 commits March 18, 2025 10:31
VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter, which allows it inherit from VPRecipeWithIRFlags.

This also allows us to remove createReduction in LoopUtils since it now only has one user in VPInstruction::ComputeReductionResult.

The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.
@lukel97 lukel97 force-pushed the loop-vectorize/VPReductionRecipe-RecurKind branch from 734abc2 to 8fcfcb4 Compare March 18, 2025 06:34
@lukel97
Copy link
Contributor Author

lukel97 commented Mar 18, 2025

I've rebased this on main now that #130881 has landed, thanks @ElvisWang123

bool IsOrdered, DebugLoc DL)
: VPRecipeWithIRFlags(
SC, Operands,
isa_and_nonnull<FPMathOperator>(I) ? FMFs : FastMathFlags(), DL),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a little odd now. Can we just pass in FMFs directly now? I'd expect that if there are no flags then FMFs should also be empty. Also, I can imagine a user might want to set the flags without passing an instruction.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, FMFs are only supported for FPMath ops, and the RecurrenceDescriptor's fast-math flags are set to fast by default for non-FP ops. Arguably we may want to fix this separately. When updating to not pass the recurrence descriptor but FMFs, we should also move the check to the caller.

Copy link
Contributor Author

@lukel97 lukel97 Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isa_and_nonnull check was just added in #130881 so that non-fp RecurKinds don't end up setting the flags, because by default a non-fp RecurrenceDescriptor will have all FastMathFlags set, see #130881 (comment)

It is a bit weird though I'll give you that. I can try and move this to the (only) call site in adjustRecipesForReductions, would that be cleaner?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved it to the caller in c244065

RecurKind getRecurrenceKind() const { return RdxKind; }
/// Return the opcode for the recurrence for the in-loop reduction.
unsigned getOpcode() const {
return RecurrenceDescriptor::getOpcode(RdxKind);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't expect RdxKind to change during the class lifetime, is it worth just caching the opcode in the class at the time of construction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just a helper method to save having to call RecurrenceDescriptor::getOpcode at the use sites, we could also just call it directly there if that's simpler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've inlined it in 23bcb59 which should be more similar to the existing code

lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 19, 2025
…pUtils. NFC

Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so that
arbitrary recurrence kinds may be used down the line.
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 19, 2025
…pUtils. NFC

Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so that
arbitrary recurrence kinds may be used down the line.
lukel97 added a commit that referenced this pull request Mar 19, 2025
…pUtils. NFC (#132014)

Split off from #131300, this splits up RecurrenceDescriptor arguments so
that arbitrary recurrence kinds may be used down the line.
@fhahn
Copy link
Contributor

fhahn commented Mar 20, 2025

There's a new use of the recurrence descriptor added in #113903 by @ElvisWang123 . Would be good to check if RecurrenceKind is enough for that patch as well

@lukel97
Copy link
Contributor Author

lukel97 commented Mar 20, 2025

There's a new use of the recurrence descriptor added in #113903 by @ElvisWang123 . Would be good to check if RecurrenceKind is enough for that patch as well

I just took a look there, #113903 should be fine since it's just using the opcode

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, as long as it doesn't complicate #113903

Comment on lines -2421 to -2438
if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK to drop this, as the store is sunk explicitly.

@lukel97
Copy link
Contributor Author

lukel97 commented Mar 21, 2025

LGTM, as long as it doesn't complicate #113903

I'm not in a rush so I can wait until #113903 lands first, and then I can rebase this on top of it @ElvisWang123

@fhahn
Copy link
Contributor

fhahn commented Mar 24, 2025

I think

LGTM, as long as it doesn't complicate #113903

I'm not in a rush so I can wait until #113903 lands first, and then I can rebase this on top of it @ElvisWang123

Just looked and if that type is the only thing that needs storing in #113903 I think it should be fine to merge this one now

@lukel97
Copy link
Contributor Author

lukel97 commented Mar 24, 2025

I think

LGTM, as long as it doesn't complicate #113903

I'm not in a rush so I can wait until #113903 lands first, and then I can rebase this on top of it @ElvisWang123

Just looked and if that type is the only thing that needs storing in #113903 I think it should be fine to merge this one now

Ok, I'll land this now then and it get it out of the way

@lukel97 lukel97 merged commit 6a8606e into llvm:main Mar 24, 2025
11 checks passed
pawosm-arm pushed a commit to pawosm-arm/llvm-project that referenced this pull request Apr 10, 2025
This is split off from llvm#131300.

A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so
when it calls createReduction it always calls createSimpleReduction.

If we replace the call then it leaves createReduction with one user in
VPInstruction::ComputeReductionResult, which we can inline and then
remove.
pawosm-arm pushed a commit to pawosm-arm/llvm-project that referenced this pull request Apr 10, 2025
…pUtils. NFC (llvm#132014)

Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so
that arbitrary recurrence kinds may be used down the line.
pawosm-arm pushed a commit to pawosm-arm/llvm-project that referenced this pull request Apr 10, 2025
…CI (llvm#131300)

VPReductionRecipes take a RecurrenceDescriptor, but only use the
RecurKind and FastMathFlags in it when executing. This patch makes the
recipe more lightweight by stripping it to only take the latter two.

The motiviation for this is to simplify an upcoming patch to support
in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to
create an Or reduction, and by using RecurKind we can create an
arbitrary reduction without needing a full RecurrenceDescriptor.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants