Skip to content

Commit 49aead5

Browse files
committed
[VPlan] Implment VPReductionRecipe::computeCost(). NFC
Implementation of `computeCost()` function for `VPReductionRecipe`.
1 parent 3137b6a commit 49aead5

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2399,6 +2399,10 @@ class VPReductionRecipe : public VPSingleDefRecipe {
23992399
/// Generate the reduction in the loop
24002400
void execute(VPTransformState &State) override;
24012401

2402+
/// Return the cost of VPReductionRecipe.
2403+
InstructionCost computeCost(ElementCount VF,
2404+
VPCostContext &Ctx) const override;
2405+
24022406
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
24032407
/// Print the recipe.
24042408
void print(raw_ostream &O, const Twine &Indent,

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,6 +2012,30 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
20122012
State.set(this, NewRed, /*IsScalar*/ true);
20132013
}
20142014

2015+
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
2016+
VPCostContext &Ctx) const {
2017+
RecurKind RdxKind = RdxDesc.getRecurrenceKind();
2018+
Type *ElementTy = RdxDesc.getRecurrenceType();
2019+
auto *VectorTy = dyn_cast<VectorType>(ToVectorTy(ElementTy, VF));
2020+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2021+
unsigned Opcode = RdxDesc.getOpcode();
2022+
2023+
if (VectorTy == nullptr)
2024+
return InstructionCost::getInvalid();
2025+
2026+
// Cost = Reduction cost + BinOp cost
2027+
InstructionCost Cost =
2028+
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
2029+
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
2030+
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2031+
return Cost + Ctx.TTI.getMinMaxReductionCost(
2032+
Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2033+
}
2034+
2035+
return Cost + Ctx.TTI.getArithmeticReductionCost(
2036+
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2037+
}
2038+
20152039
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
20162040
void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
20172041
VPSlotTracker &SlotTracker) const {

0 commit comments

Comments
 (0)