Skip to content

Commit 442d1dd

Browse files
committed
[VPlan] Implement VPlan-based pattern matching for ExtendedRed and MulAccRed. NFCI
This patch implements VPlan-based pattern matching for extended reductions (ExtendedReduction) and multiply-accumulate reductions (MulAccReduction). In these reduction patterns, the extend instructions and the mul instruction can be folded into the reduction instruction, so their cost is free. We add `FoldedRecipes` to `VPCostContext` to record the recipes that are folded into other recipes. ExtendedReduction patterns: reduce(ext(...)). MulAccReduction patterns: reduce.add(mul(...)); reduce.add(mul(ext(...), ext(...))); reduce.add(ext(mul(...))); reduce.add(ext(mul(ext(...), ext(...)))). Ref: the original instruction-based implementation: https://reviews.llvm.org/D93476
1 parent 35bdec7 commit 442d1dd

File tree

3 files changed

+129
-57
lines changed

3 files changed

+129
-57
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7303,51 +7303,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
73037303
Cost += ReductionCost;
73047304
continue;
73057305
}
7306-
7307-
const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
7308-
SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
7309-
ChainOps.end());
7310-
auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
7311-
return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
7312-
};
7313-
// Also include the operands of instructions in the chain, as the cost-model
7314-
// may mark extends as free.
7315-
//
7316-
// For ARM, some of the instruction can folded into the reducion
7317-
// instruction. So we need to mark all folded instructions free.
7318-
// For example: We can fold reduce(mul(ext(A), ext(B))) into one
7319-
// instruction.
7320-
for (auto *ChainOp : ChainOps) {
7321-
for (Value *Op : ChainOp->operands()) {
7322-
if (auto *I = dyn_cast<Instruction>(Op)) {
7323-
ChainOpsAndOperands.insert(I);
7324-
if (I->getOpcode() == Instruction::Mul) {
7325-
auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
7326-
auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
7327-
if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
7328-
Ext0->getOpcode() == Ext1->getOpcode()) {
7329-
ChainOpsAndOperands.insert(Ext0);
7330-
ChainOpsAndOperands.insert(Ext1);
7331-
}
7332-
}
7333-
}
7334-
}
7335-
}
7336-
7337-
// Pre-compute the cost for I, if it has a reduction pattern cost.
7338-
for (Instruction *I : ChainOpsAndOperands) {
7339-
auto ReductionCost = CM.getReductionPatternCost(
7340-
I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
7341-
if (!ReductionCost)
7342-
continue;
7343-
7344-
assert(!CostCtx.SkipCostComputation.contains(I) &&
7345-
"reduction op visited multiple times");
7346-
CostCtx.SkipCostComputation.insert(I);
7347-
LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
7348-
<< ":\n in-loop reduction " << *I << "\n");
7349-
Cost += *ReductionCost;
7350-
}
73517306
}
73527307

73537308
// Pre-compute the costs for branches except for the backedge, as the number

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,8 @@ struct VPCostContext {
725725
LLVMContext &LLVMCtx;
726726
LoopVectorizationCostModel &CM;
727727
SmallPtrSet<Instruction *, 8> SkipCostComputation;
728+
/// Contains recipes that are folded into other recipes.
729+
SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
728730

729731
VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
730732
Type *CanIVTy, LoopVectorizationCostModel &CM)

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 127 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
299299
UI = &WidenMem->getIngredient();
300300

301301
InstructionCost RecipeCost;
302-
if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
302+
if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
303+
(Ctx.FoldedRecipes.contains(VF) &&
304+
Ctx.FoldedRecipes.at(VF).contains(this))) {
303305
RecipeCost = 0;
304306
} else {
305307
RecipeCost = computeCost(VF, Ctx);
@@ -2188,30 +2190,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
21882190
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
21892191
unsigned Opcode = RdxDesc.getOpcode();
21902192

2191-
// TODO: Support any-of and in-loop reductions.
2193+
// TODO: Support any-of reductions.
21922194
assert(
21932195
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
21942196
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
21952197
"Any-of reduction not implemented in VPlan-based cost model currently.");
2196-
assert(
2197-
(!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
2198-
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
2199-
"In-loop reduction not implemented in VPlan-based cost model currently.");
22002198

22012199
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
22022200
"Inferred type and recurrence type mismatch.");
22032201

2204-
// Cost = Reduction cost + BinOp cost
2205-
InstructionCost Cost =
2202+
// BaseCost = Reduction cost + BinOp cost
2203+
InstructionCost BaseCost =
22062204
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
22072205
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
22082206
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2209-
return Cost + Ctx.TTI.getMinMaxReductionCost(
2210-
Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2207+
BaseCost += Ctx.TTI.getMinMaxReductionCost(
2208+
Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2209+
} else {
2210+
BaseCost += Ctx.TTI.getArithmeticReductionCost(
2211+
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
22112212
}
22122213

2213-
return Cost + Ctx.TTI.getArithmeticReductionCost(
2214-
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2214+
using namespace llvm::VPlanPatternMatch;
2215+
auto GetMulAccReductionCost =
2216+
[&](const VPReductionRecipe *Red) -> InstructionCost {
2217+
VPValue *A, *B;
2218+
InstructionCost InnerExt0Cost = 0;
2219+
InstructionCost InnerExt1Cost = 0;
2220+
InstructionCost ExtCost = 0;
2221+
InstructionCost MulCost = 0;
2222+
2223+
VectorType *SrcVecTy = VectorTy;
2224+
Type *InnerExt0Ty;
2225+
Type *InnerExt1Ty;
2226+
Type *MaxInnerExtTy;
2227+
bool IsUnsigned = true;
2228+
bool HasOuterExt = false;
2229+
2230+
auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
2231+
Red->getVecOp()->getDefiningRecipe());
2232+
VPRecipeBase *Mul;
2233+
// Try to match outer extend reduce.add(ext(...))
2234+
if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
2235+
cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
2236+
IsUnsigned =
2237+
Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
2238+
ExtCost = Ext->computeCost(VF, Ctx);
2239+
Mul = Ext->getOperand(0)->getDefiningRecipe();
2240+
HasOuterExt = true;
2241+
} else {
2242+
Mul = Red->getVecOp()->getDefiningRecipe();
2243+
}
2244+
2245+
// Match reduce.add(mul())
2246+
if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
2247+
cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
2248+
MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
2249+
auto *InnerExt0 =
2250+
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
2251+
auto *InnerExt1 =
2252+
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
2253+
bool HasInnerExt = false;
2254+
// Try to match inner extends.
2255+
if (InnerExt0 && InnerExt1 &&
2256+
match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
2257+
match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
2258+
InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
2259+
(InnerExt0->getNumUsers() > 0 &&
2260+
!InnerExt0->hasMoreThanOneUniqueUser()) &&
2261+
(InnerExt1->getNumUsers() > 0 &&
2262+
!InnerExt1->hasMoreThanOneUniqueUser())) {
2263+
InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
2264+
InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
2265+
Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
2266+
Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
2267+
Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
2268+
InnerExt1Ty->getIntegerBitWidth()
2269+
? InnerExt0Ty
2270+
: InnerExt1Ty;
2271+
SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
2272+
IsUnsigned = true;
2273+
HasInnerExt = true;
2274+
}
2275+
InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
2276+
IsUnsigned, ElementTy, SrcVecTy, CostKind);
2277+
// Check if folding ext/mul into MulAccReduction is profitable.
2278+
if (MulAccRedCost.isValid() &&
2279+
MulAccRedCost <
2280+
ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
2281+
if (HasInnerExt) {
2282+
Ctx.FoldedRecipes[VF].insert(InnerExt0);
2283+
Ctx.FoldedRecipes[VF].insert(InnerExt1);
2284+
}
2285+
Ctx.FoldedRecipes[VF].insert(Mul);
2286+
if (HasOuterExt)
2287+
Ctx.FoldedRecipes[VF].insert(Ext);
2288+
return MulAccRedCost;
2289+
}
2290+
}
2291+
return InstructionCost::getInvalid();
2292+
};
2293+
2294+
// Match reduce(ext(...))
2295+
auto GetExtendedReductionCost =
2296+
[&](const VPReductionRecipe *Red) -> InstructionCost {
2297+
VPValue *VecOp = Red->getVecOp();
2298+
VPValue *A;
2299+
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
2300+
VPWidenCastRecipe *Ext =
2301+
cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
2302+
bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
2303+
InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
2304+
auto *ExtVecTy =
2305+
cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
2306+
InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
2307+
Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
2308+
CostKind);
2309+
// Check if folding ext into ExtendedReduction is profitable.
2310+
if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
2311+
Ctx.FoldedRecipes[VF].insert(Ext);
2312+
return ExtendedRedCost;
2313+
}
2314+
}
2315+
return InstructionCost::getInvalid();
2316+
};
2317+
2318+
// Match MulAccReduction patterns.
2319+
InstructionCost MulAccCost = GetMulAccReductionCost(this);
2320+
if (MulAccCost.isValid())
2321+
return MulAccCost;
2322+
2323+
// Match ExtendedReduction patterns.
2324+
InstructionCost ExtendedCost = GetExtendedReductionCost(this);
2325+
if (ExtendedCost.isValid())
2326+
return ExtendedCost;
2327+
2328+
// Default cost.
2329+
return BaseCost;
22152330
}
22162331

22172332
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

0 commit comments

Comments
 (0)