-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI #131300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI #131300
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Luke Lau (lukel97) ChangesVPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter, which allows it inherit from VPRecipeWithIRFlags. This also allows us to remove createReduction in LoopUtils since it now only has one user in VPInstruction::ComputeReductionResult. The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor. Full diff: https://github.com/llvm/llvm-project/pull/131300.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 8f4c0c88336ac..3ad7b8f17856c 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
RecurKind RdxKind);
/// Overloaded function to generate vector-predication intrinsics for
/// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
- const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
+ FastMathFlags FMFs);
/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -427,20 +427,13 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc);
-/// Create a generic reduction using a recurrence descriptor \p Desc
-/// Fast-math-flags are propagated using the RecurrenceDescriptor.
-Value *createReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
- Value *Src, PHINode *OrigPhi = nullptr);
-
/// Create an ordered reduction intrinsic using the given recurrence
-/// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc, Value *Src,
+/// kind \p Kind.
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind Kind, Value *Src,
Value *Start);
/// Overloaded function to generate vector-predication intrinsics for ordered
/// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
- const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind Kind, Value *Src,
Value *Start);
/// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 84c08556f8a25..b20ce27f8cfb3 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,41 +1333,20 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
}
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
- const RecurrenceDescriptor &Desc) {
- RecurKind Kind = Desc.getRecurrenceKind();
+ RecurKind Kind, FastMathFlags FMFs) {
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
"AnyOf reduction is not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
- Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+ Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
Value *Ops[] = {Iden, Src};
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
}
-Value *llvm::createReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc, Value *Src,
- PHINode *OrigPhi) {
- // TODO: Support in-order reductions based on the recurrence descriptor.
- // All ops in the reduction inherit fast-math-flags from the recurrence
- // descriptor.
- IRBuilderBase::FastMathFlagGuard FMFGuard(B);
- B.setFastMathFlags(Desc.getFastMathFlags());
-
- RecurKind RK = Desc.getRecurrenceKind();
- if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
- return createAnyOfReduction(B, Src, Desc, OrigPhi);
- if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
- return createFindLastIVReduction(B, Src, Desc);
-
- return createSimpleReduction(B, Src, RK);
-}
-
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1375,11 +1354,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
return B.CreateFAddReduce(Start, Src);
}
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index f78eb84b0c445..890eff8d28b7f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2237,19 +2237,21 @@ class VPInterleaveRecipe : public VPRecipeBase {
/// A recipe to represent inloop reduction operations, performing a reduction on
/// a vector operand into a scalar value, and adding the result to a chain.
/// The Operands are {ChainOp, VecOp, [Condition]}.
-class VPReductionRecipe : public VPSingleDefRecipe {
+class VPReductionRecipe : public VPRecipeWithIRFlags {
/// The recurrence decriptor for the reduction in question.
- const RecurrenceDescriptor &RdxDesc;
+ RecurKind RdxKind;
bool IsOrdered;
/// Whether the reduction is conditional.
bool IsConditional = false;
protected:
- VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *I, ArrayRef<VPValue *> Operands,
- VPValue *CondOp, bool IsOrdered, DebugLoc DL)
- : VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
+ VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
+ FastMathFlags FMFs, Instruction *I,
+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
+ bool IsOrdered, DebugLoc DL)
+ : VPRecipeWithIRFlags(SC, Operands, FMFs, DL), RdxKind(RdxKind),
IsOrdered(IsOrdered) {
+ setUnderlyingValue(I);
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
@@ -2257,19 +2259,25 @@ class VPReductionRecipe : public VPSingleDefRecipe {
}
public:
- VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+ VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
bool IsOrdered, DebugLoc DL = {})
- : VPReductionRecipe(VPDef::VPReductionSC, R, I,
+ : VPReductionRecipe(VPRecipeBase::VPReductionSC, RdxKind, FMFs, I,
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
IsOrdered, DL) {}
+ VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+ bool IsOrdered, DebugLoc DL = {})
+ : VPReductionRecipe(R.getRecurrenceKind(), R.getFastMathFlags(), I,
+ ChainOp, VecOp, CondOp, IsOrdered, DL) {}
+
~VPReductionRecipe() override = default;
VPReductionRecipe *clone() override {
- return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
- getVecOp(), getCondOp(), IsOrdered,
- getDebugLoc());
+ return new VPReductionRecipe(RdxKind, getFastMathFlags(),
+ getUnderlyingInstr(), getChainOp(), getVecOp(),
+ getCondOp(), IsOrdered, getDebugLoc());
}
static inline bool classof(const VPRecipeBase *R) {
@@ -2295,9 +2303,11 @@ class VPReductionRecipe : public VPSingleDefRecipe {
VPSlotTracker &SlotTracker) const override;
#endif
- /// Return the recurrence decriptor for the in-loop reduction.
- const RecurrenceDescriptor &getRecurrenceDescriptor() const {
- return RdxDesc;
+ /// Return the recurrence kind for the in-loop reduction.
+ RecurKind getRecurrenceKind() const { return RdxKind; }
+ /// Return the opcode for the recurrence for the in-loop reduction.
+ unsigned getOpcode() const {
+ return RecurrenceDescriptor::getOpcode(RdxKind);
}
/// Return true if the in-loop reduction is ordered.
bool isOrdered() const { return IsOrdered; };
@@ -2321,7 +2331,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
public:
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
: VPReductionRecipe(
- VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
+ VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
+ R.getFastMathFlags(),
cast_or_null<Instruction>(R.getUnderlyingValue()),
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
R.isOrdered(), R.getDebugLoc()) {}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 6e396eda6aac6..3ea737f1a088b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -666,8 +666,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
!PhiR->isInLoop()) {
- ReducedPartRdx =
- createReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
+ IRBuilderBase::FastMathFlagGuard FMFG(Builder);
+ Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+ ReducedPartRdx =
+ createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
+ else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
+ ReducedPartRdx =
+ createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
+ else
+ ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
+
// If the reduction can be performed in a smaller type, we need to extend
// the reduction to the wider type before we branch to the original loop.
if (PhiTy != RdxDesc.getRecurrenceType())
@@ -2263,12 +2272,13 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
void VPReductionRecipe::execute(VPTransformState &State) {
assert(!State.Lane && "Reduction being replicated.");
Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
- RecurKind Kind = RdxDesc.getRecurrenceKind();
+ RecurKind Kind = getRecurrenceKind();
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
"In-loop AnyOf reductions aren't currently supported");
+
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
- State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+ State.Builder.setFastMathFlags(getFastMathFlags());
State.setDebugLocFrom(getDebugLoc());
Value *NewVecOp = State.get(getVecOp());
if (VPValue *Cond = getCondOp()) {
@@ -2276,8 +2286,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
- Value *Start =
- getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
+ Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
if (State.VF.isVector())
Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
@@ -2289,21 +2298,20 @@ void VPReductionRecipe::execute(VPTransformState &State) {
if (IsOrdered) {
if (State.VF.isVector())
NewRed =
- createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+ createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
else
- NewRed = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
+ NewRed = State.Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(),
+ PrevInChain, NewVecOp);
PrevInChain = NewRed;
NextInChain = NewRed;
} else {
PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
- NewRed = createReduction(State.Builder, RdxDesc, NewVecOp);
+ NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
- NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
- NewRed, PrevInChain);
+ NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
else
NextInChain = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
+ (Instruction::BinaryOps)getOpcode(), NewRed, PrevInChain);
}
State.set(this, NextInChain, /*IsScalar*/ true);
}
@@ -2314,10 +2322,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
auto &Builder = State.Builder;
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
- Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+ Builder.setFastMathFlags(getFastMathFlags());
- RecurKind Kind = RdxDesc.getRecurrenceKind();
+ RecurKind Kind = getRecurrenceKind();
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
Value *VecOp = State.get(getVecOp());
Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2334,24 +2341,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
Value *NewRed;
if (isOrdered()) {
- NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+ NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
} else {
- NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+ NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
else
- NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
- NewRed, Prev);
+ NewRed = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), NewRed,
+ Prev);
}
State.set(this, NewRed, /*IsScalar*/ true);
}
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
- RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+ RecurKind RdxKind = getRecurrenceKind();
Type *ElementTy = Ctx.Types.inferScalarType(this);
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
- unsigned Opcode = RdxDesc.getOpcode();
// TODO: Support any-of and in-loop reductions.
assert(
@@ -2363,20 +2369,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
"In-loop reduction not implemented in VPlan-based cost model currently.");
- assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
- "Inferred type and recurrence type mismatch.");
-
// Cost = Reduction cost + BinOp cost
InstructionCost Cost =
- Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
+ Ctx.TTI.getArithmeticInstrCost(getOpcode(), ElementTy, Ctx.CostKind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
return Cost + Ctx.TTI.getMinMaxReductionCost(
- Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+ Id, VectorTy, getFastMathFlags(), Ctx.CostKind);
}
return Cost + Ctx.TTI.getArithmeticReductionCost(
- Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+ getOpcode(), VectorTy, getFastMathFlags(), Ctx.CostKind);
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2389,21 +2392,21 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << " reduce."
+ << Instruction::getOpcodeName(
+ RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+ << " (";
getVecOp()->printAsOperand(O, SlotTracker);
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (RdxDesc.IntermediateStore)
- O << " (with final reduction value stored in invariant address sank "
- "outside of loop)";
}
void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+ RecurKind Kind = getRecurrenceKind();
O << Indent << "REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
@@ -2411,7 +2414,9 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << " vp.reduce."
+ << Instruction::getOpcodeName(RecurrenceDescriptor::getOpcode(Kind))
+ << " (";
getVecOp()->printAsOperand(O, SlotTracker);
O << ", ";
getEVL()->printAsOperand(O, SlotTracker);
@@ -2420,9 +2425,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (RdxDesc.IntermediateStore)
- O << " (with final reduction value stored in invariant address sank "
- "outside of loop)";
}
#endif
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 00d8de67a3b40..791db17b9a7d1 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
; CHECK-NEXT: vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
; CHECK-NEXT: WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
-; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
+; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
; CHECK-NEXT: EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
; CHECK-NEXT: EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
; CHECK-NEXT: No successors
|
@llvm/pr-subscribers-vectorizers Author: Luke Lau (lukel97) ChangesVPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter, which allows it inherit from VPRecipeWithIRFlags. This also allows us to remove createReduction in LoopUtils since it now only has one user in VPInstruction::ComputeReductionResult. The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor. Full diff: https://github.com/llvm/llvm-project/pull/131300.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 8f4c0c88336ac..3ad7b8f17856c 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
RecurKind RdxKind);
/// Overloaded function to generate vector-predication intrinsics for
/// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
- const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
+ FastMathFlags FMFs);
/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -427,20 +427,13 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc);
-/// Create a generic reduction using a recurrence descriptor \p Desc
-/// Fast-math-flags are propagated using the RecurrenceDescriptor.
-Value *createReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
- Value *Src, PHINode *OrigPhi = nullptr);
-
/// Create an ordered reduction intrinsic using the given recurrence
-/// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc, Value *Src,
+/// kind \p Kind.
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind Kind, Value *Src,
Value *Start);
/// Overloaded function to generate vector-predication intrinsics for ordered
/// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
- const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind Kind, Value *Src,
Value *Start);
/// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 84c08556f8a25..b20ce27f8cfb3 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,41 +1333,20 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
}
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
- const RecurrenceDescriptor &Desc) {
- RecurKind Kind = Desc.getRecurrenceKind();
+ RecurKind Kind, FastMathFlags FMFs) {
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
"AnyOf reduction is not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
- Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+ Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
Value *Ops[] = {Iden, Src};
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
}
-Value *llvm::createReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc, Value *Src,
- PHINode *OrigPhi) {
- // TODO: Support in-order reductions based on the recurrence descriptor.
- // All ops in the reduction inherit fast-math-flags from the recurrence
- // descriptor.
- IRBuilderBase::FastMathFlagGuard FMFGuard(B);
- B.setFastMathFlags(Desc.getFastMathFlags());
-
- RecurKind RK = Desc.getRecurrenceKind();
- if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
- return createAnyOfReduction(B, Src, Desc, OrigPhi);
- if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
- return createFindLastIVReduction(B, Src, Desc);
-
- return createSimpleReduction(B, Src, RK);
-}
-
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1375,11 +1354,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
return B.CreateFAddReduce(Start, Src);
}
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index f78eb84b0c445..890eff8d28b7f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2237,19 +2237,21 @@ class VPInterleaveRecipe : public VPRecipeBase {
/// A recipe to represent inloop reduction operations, performing a reduction on
/// a vector operand into a scalar value, and adding the result to a chain.
/// The Operands are {ChainOp, VecOp, [Condition]}.
-class VPReductionRecipe : public VPSingleDefRecipe {
+class VPReductionRecipe : public VPRecipeWithIRFlags {
/// The recurrence decriptor for the reduction in question.
- const RecurrenceDescriptor &RdxDesc;
+ RecurKind RdxKind;
bool IsOrdered;
/// Whether the reduction is conditional.
bool IsConditional = false;
protected:
- VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *I, ArrayRef<VPValue *> Operands,
- VPValue *CondOp, bool IsOrdered, DebugLoc DL)
- : VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
+ VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
+ FastMathFlags FMFs, Instruction *I,
+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
+ bool IsOrdered, DebugLoc DL)
+ : VPRecipeWithIRFlags(SC, Operands, FMFs, DL), RdxKind(RdxKind),
IsOrdered(IsOrdered) {
+ setUnderlyingValue(I);
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
@@ -2257,19 +2259,25 @@ class VPReductionRecipe : public VPSingleDefRecipe {
}
public:
- VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+ VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
bool IsOrdered, DebugLoc DL = {})
- : VPReductionRecipe(VPDef::VPReductionSC, R, I,
+ : VPReductionRecipe(VPRecipeBase::VPReductionSC, RdxKind, FMFs, I,
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
IsOrdered, DL) {}
+ VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+ bool IsOrdered, DebugLoc DL = {})
+ : VPReductionRecipe(R.getRecurrenceKind(), R.getFastMathFlags(), I,
+ ChainOp, VecOp, CondOp, IsOrdered, DL) {}
+
~VPReductionRecipe() override = default;
VPReductionRecipe *clone() override {
- return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
- getVecOp(), getCondOp(), IsOrdered,
- getDebugLoc());
+ return new VPReductionRecipe(RdxKind, getFastMathFlags(),
+ getUnderlyingInstr(), getChainOp(), getVecOp(),
+ getCondOp(), IsOrdered, getDebugLoc());
}
static inline bool classof(const VPRecipeBase *R) {
@@ -2295,9 +2303,11 @@ class VPReductionRecipe : public VPSingleDefRecipe {
VPSlotTracker &SlotTracker) const override;
#endif
- /// Return the recurrence decriptor for the in-loop reduction.
- const RecurrenceDescriptor &getRecurrenceDescriptor() const {
- return RdxDesc;
+ /// Return the recurrence kind for the in-loop reduction.
+ RecurKind getRecurrenceKind() const { return RdxKind; }
+ /// Return the opcode for the recurrence for the in-loop reduction.
+ unsigned getOpcode() const {
+ return RecurrenceDescriptor::getOpcode(RdxKind);
}
/// Return true if the in-loop reduction is ordered.
bool isOrdered() const { return IsOrdered; };
@@ -2321,7 +2331,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
public:
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
: VPReductionRecipe(
- VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
+ VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
+ R.getFastMathFlags(),
cast_or_null<Instruction>(R.getUnderlyingValue()),
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
R.isOrdered(), R.getDebugLoc()) {}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 6e396eda6aac6..3ea737f1a088b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -666,8 +666,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
!PhiR->isInLoop()) {
- ReducedPartRdx =
- createReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
+ IRBuilderBase::FastMathFlagGuard FMFG(Builder);
+ Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+ ReducedPartRdx =
+ createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
+ else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
+ ReducedPartRdx =
+ createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
+ else
+ ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
+
// If the reduction can be performed in a smaller type, we need to extend
// the reduction to the wider type before we branch to the original loop.
if (PhiTy != RdxDesc.getRecurrenceType())
@@ -2263,12 +2272,13 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
void VPReductionRecipe::execute(VPTransformState &State) {
assert(!State.Lane && "Reduction being replicated.");
Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
- RecurKind Kind = RdxDesc.getRecurrenceKind();
+ RecurKind Kind = getRecurrenceKind();
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
"In-loop AnyOf reductions aren't currently supported");
+
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
- State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+ State.Builder.setFastMathFlags(getFastMathFlags());
State.setDebugLocFrom(getDebugLoc());
Value *NewVecOp = State.get(getVecOp());
if (VPValue *Cond = getCondOp()) {
@@ -2276,8 +2286,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
- Value *Start =
- getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
+ Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
if (State.VF.isVector())
Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
@@ -2289,21 +2298,20 @@ void VPReductionRecipe::execute(VPTransformState &State) {
if (IsOrdered) {
if (State.VF.isVector())
NewRed =
- createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+ createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
else
- NewRed = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
+ NewRed = State.Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(),
+ PrevInChain, NewVecOp);
PrevInChain = NewRed;
NextInChain = NewRed;
} else {
PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
- NewRed = createReduction(State.Builder, RdxDesc, NewVecOp);
+ NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
- NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
- NewRed, PrevInChain);
+ NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
else
NextInChain = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
+ (Instruction::BinaryOps)getOpcode(), NewRed, PrevInChain);
}
State.set(this, NextInChain, /*IsScalar*/ true);
}
@@ -2314,10 +2322,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
auto &Builder = State.Builder;
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
- Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+ Builder.setFastMathFlags(getFastMathFlags());
- RecurKind Kind = RdxDesc.getRecurrenceKind();
+ RecurKind Kind = getRecurrenceKind();
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
Value *VecOp = State.get(getVecOp());
Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2334,24 +2341,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
Value *NewRed;
if (isOrdered()) {
- NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+ NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
} else {
- NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+ NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
else
- NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
- NewRed, Prev);
+ NewRed = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), NewRed,
+ Prev);
}
State.set(this, NewRed, /*IsScalar*/ true);
}
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
- RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+ RecurKind RdxKind = getRecurrenceKind();
Type *ElementTy = Ctx.Types.inferScalarType(this);
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
- unsigned Opcode = RdxDesc.getOpcode();
// TODO: Support any-of and in-loop reductions.
assert(
@@ -2363,20 +2369,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
"In-loop reduction not implemented in VPlan-based cost model currently.");
- assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
- "Inferred type and recurrence type mismatch.");
-
// Cost = Reduction cost + BinOp cost
InstructionCost Cost =
- Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
+ Ctx.TTI.getArithmeticInstrCost(getOpcode(), ElementTy, Ctx.CostKind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
return Cost + Ctx.TTI.getMinMaxReductionCost(
- Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+ Id, VectorTy, getFastMathFlags(), Ctx.CostKind);
}
return Cost + Ctx.TTI.getArithmeticReductionCost(
- Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+ getOpcode(), VectorTy, getFastMathFlags(), Ctx.CostKind);
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2389,21 +2392,21 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << " reduce."
+ << Instruction::getOpcodeName(
+ RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+ << " (";
getVecOp()->printAsOperand(O, SlotTracker);
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (RdxDesc.IntermediateStore)
- O << " (with final reduction value stored in invariant address sank "
- "outside of loop)";
}
void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+ RecurKind Kind = getRecurrenceKind();
O << Indent << "REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
@@ -2411,7 +2414,9 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << " vp.reduce."
+ << Instruction::getOpcodeName(RecurrenceDescriptor::getOpcode(Kind))
+ << " (";
getVecOp()->printAsOperand(O, SlotTracker);
O << ", ";
getEVL()->printAsOperand(O, SlotTracker);
@@ -2420,9 +2425,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (RdxDesc.IntermediateStore)
- O << " (with final reduction value stored in invariant address sank "
- "outside of loop)";
}
#endif
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 00d8de67a3b40..791db17b9a7d1 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
; CHECK-NEXT: vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
; CHECK-NEXT: WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
-; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
+; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
; CHECK-NEXT: EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
; CHECK-NEXT: EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
; CHECK-NEXT: No successors
|
FYI, I have two similar patches of changing the parent of VPReductionRecipe to VPRecipeWithIRFlags.
My patches are focus on eliminating the dependency of the underlying instruction when printing the fastMathFlags and control the FMF by the VPRecipeWIthIRFlags, which is different to your patch that only construct the VPReductionRecipe by the RecurrenceKind. |
Oh sorry I never noticed those patches! Yeah, adding the VPRecipeWIthIRFlags bit is the part that is similar here, I use it in this patch to store the FMF. I can try and stack this patch on top of #130881 then. |
This is split off from llvm#131300. A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so when it calls createReduction it always calls createSimpleReduction. If we replace the call then it leaves createReduction with one user in VPInstruction::ComputeReductionResult, which we can inline and then remove.
This is split off from #131300. A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so when it calls createReduction it always calls createSimpleReduction. If we replace the call then it leaves createReduction with one user in VPInstruction::ComputeReductionResult, which we can inline and then remove.
aa6ada6
to
58d7419
Compare
This is split off from llvm#131300. A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so when it calls createReduction it always calls createSimpleReduction. If we replace the call then it leaves createReduction with one user in VPInstruction::ComputeReductionResult, which we can inline and then remove.
58d7419
to
c14193a
Compare
VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter, which allows it inherit from VPRecipeWithIRFlags. This also allows us to remove createReduction in LoopUtils since it now only has one user in VPInstruction::ComputeReductionResult. The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.
734abc2
to
8fcfcb4
Compare
I've rebased this on main now that #130881 has landed, thanks @ElvisWang123 |
bool IsOrdered, DebugLoc DL) | ||
: VPRecipeWithIRFlags( | ||
SC, Operands, | ||
isa_and_nonnull<FPMathOperator>(I) ? FMFs : FastMathFlags(), DL), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks a little odd now. Can we just pass in FMFs
directly now? I'd expect that if there are no flags then FMFs
should also be empty. Also, I can imagine a user might want to set the flags without passing an instruction.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, FMFs are only supported for FPMath ops, and the RecurrenceDescriptor's fast-math flags are set to fast by default for non-FP ops. Arguably we may want to fix this separately. When updating to not pass the recurrence descriptor but FMFs, we should also move the check to the caller.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isa_and_nonnull check was just added in #130881 so that non-fp RecurKinds don't end up setting the flags, because by default a non-fp RecurrenceDescriptor will have all FastMathFlags set, see #130881 (comment)
It is a bit weird though I'll give you that. I can try and move this to the (only) call site in adjustRecipesForReductions
, would that be cleaner?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've moved it to the caller in c244065
RecurKind getRecurrenceKind() const { return RdxKind; } | ||
/// Return the opcode for the recurrence for the in-loop reduction. | ||
unsigned getOpcode() const { | ||
return RecurrenceDescriptor::getOpcode(RdxKind); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we don't expect RdxKind
to change during the class lifetime, is it worth just caching the opcode in the class at the time of construction?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just a helper method to save having to call RecurrenceDescriptor::getOpcode
at the use sites, we could also just call it directly there if that's simpler?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've inlined it in 23bcb59 which should be more similar to the existing code
…pUtils. NFC Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so that arbitrary recurrence kinds may be used down the line.
…pUtils. NFC Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so that arbitrary recurrence kinds may be used down the line.
…ctionRecipe-RecurKind
There's a new use of the recurrence descriptor added in #113903 by @ElvisWang123 . Would be good to check if RecurrenceKind is enough for that patch as well |
I just took a look there, #113903 should be fine since it's just using the opcode |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, as long as it doesn't complicate #113903
if (RdxDesc.IntermediateStore) | ||
O << " (with final reduction value stored in invariant address sank " | ||
"outside of loop)"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK to drop this, as the store is sunk explicitly.
I'm not in a rush so I can wait until #113903 lands first, and then I can rebase this on top of it @ElvisWang123 |
I think
Just looked and if that type is the only thing that needs storing in #113903 I think it should be fine to merge this one now |
Ok, I'll land this now then and it get it out of the way |
This is split off from llvm#131300. A VPReductionRecipe will never have a AnyOf or FindLastIV recurrence, so when it calls createReduction it always calls createSimpleReduction. If we replace the call then it leaves createReduction with one user in VPInstruction::ComputeReductionResult, which we can inline and then remove.
…pUtils. NFC (llvm#132014) Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so that arbitrary recurrence kinds may be used down the line.
…CI (llvm#131300) VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter two. The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.
VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter two.
The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.