Commit 6a38817

Recognize the reduction patterns from D93476 in computeCost().

1 parent f2c2b8f

File tree: 2 files changed (+159 −7 lines)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 15 additions & 0 deletions
@@ -7220,13 +7220,28 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
                                                     ChainOps.end());
   // Also include the operands of instructions in the chain, as the cost-model
   // may mark extends as free.
+  // We only handle the reduction cost in the VPlan-based cost model
+  // currently.
+  // TODO: Handle this calculation in VPWidenRecipe and VPWidenCastRecipe.
   for (auto *ChainOp : ChainOps) {
     for (Value *Op : ChainOp->operands()) {
       if (auto *I = dyn_cast<Instruction>(Op))
         ChainOpsAndOperands.insert(I);
     }
   }
 
+  // Since the reduction cost is implemented in VPReductionRecipe, remove the
+  // reduction instruction here to prevent VPReductionRecipe::computeCost from
+  // being skipped.
+  // TODO: Remove the following check once the VPlan-based cost model fully
+  // supports reduction pattern costs.
+  for (auto *I : ChainOpsAndOperands) {
+    if (I->getOpcode() == RdxDesc.getOpcode()) {
+      ChainOpsAndOperands.remove(I);
+      break;
+    }
+  }
+
   // Pre-compute the cost for I, if it has a reduction pattern cost.
   for (Instruction *I : ChainOpsAndOperands) {
     auto ReductionCost = CM.getReductionPatternCost(
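
For illustration, a minimal hypothetical IR sketch (not part of the patch) of the kind of reduction chain this affects: the final add carries RdxDesc.getOpcode(), so the new loop above drops it from ChainOpsAndOperands and its cost is supplied by VPReductionRecipe::computeCost rather than the legacy getReductionPatternCost.

; Hypothetical scalar reduction chain, for illustration only.
loop:
  %rdx = phi i32 [ 0, %entry ], [ %red, %loop ]
  %gep = getelementptr inbounds i8, ptr %a, i64 %iv
  %x   = load i8, ptr %gep
  %ext = sext i8 %x to i32
  ; %red matches the reduction opcode (add), so precomputeCosts now skips
  ; it here and defers its cost to VPReductionRecipe::computeCost.
  %red = add i32 %rdx, %ext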

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 144 additions & 7 deletions
@@ -2022,6 +2022,11 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
+static bool isZExtOrSExt(Instruction::CastOps CastOpcode) {
+  return CastOpcode == Instruction::CastOps::ZExt ||
+         CastOpcode == Instruction::CastOps::SExt;
+}
+
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
@@ -2030,17 +2035,149 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  InstructionCost BaseCost;
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost = Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost = Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
+  // normal fmul instruction to the cost of the fadd reduction.
+  if (RdxKind == RecurKind::FMulAdd)
+    BaseCost +=
+        Ctx.TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind);
+
+  // If we're using ordered reductions then we can just return the base cost
+  // here, since getArithmeticReductionCost calculates the full ordered
+  // reduction cost when FP reassociation is not allowed.
+  if (IsOrdered && Opcode == Instruction::FAdd)
+    return BaseCost;
+
+  // Special case for ARM from D93476: the reduction instruction can be
+  // substituted in the following case.
+  //
+  //   %sa = sext <16 x i8> A to <16 x i32>
+  //   %sb = sext <16 x i8> B to <16 x i32>
+  //   %m = mul <16 x i32> %sa, %sb
+  //   %r = vecreduce.add(%m)
+  // ->
+  //   R = VMLADAV A, B
+  //
+  // There are other instructions for performing add reductions of
+  // v4i32/v8i16/v16i8 into i32 (VADDV), for doing the same with v4i32->i64
+  // (VADDLV) and for performing a v4i32/v8i16 MLA into an i64 (VMLALDAV).
+  //
+  // We look for one of the following patterns, taking the minimal acceptable
+  // cost:
+  //   reduce.add(ext(mul(ext(A), ext(B)))) or
+  //   reduce(ext(A)) or
+  //   reduce.add(mul(ext(A), ext(B))) or
+  //   reduce.add(mul(A, B)) or
+  //   reduce(A).
+
+  // Try to match reduce(ext(...))
+  auto *Ext = dyn_cast<VPWidenCastRecipe>(getVecOp());
+  if (Ext && isZExtOrSExt(Ext->getOpcode())) {
+    bool isUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+
+    // Try to match reduce.add(ext(mul(...)))
+    auto *ExtTy = cast<VectorType>(
+        ToVectorTy(Ext->getOperand(0)->getUnderlyingValue()->getType(), VF));
+    auto *Mul = dyn_cast_if_present<VPWidenRecipe>(
+        Ext->getOperand(0)->getDefiningRecipe());
+    if (Mul && Mul->getOpcode() == Instruction::Mul &&
+        Opcode == Instruction::Add) {
+      auto *MulTy = cast<VectorType>(
+          ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+      auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+      auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+
+      // Match reduce.add(ext(mul(ext(A), ext(B))))
+      if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+          isZExtOrSExt(InnerExt1->getOpcode()) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+        Type *InnerExt0Ty =
+            InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+        Type *InnerExt1Ty =
+            InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+        // Get the largest type.
+        auto *MaxExtVecTy = cast<VectorType>(
+            ToVectorTy(InnerExt0Ty->getIntegerBitWidth() >
+                               InnerExt1Ty->getIntegerBitWidth()
+                           ? InnerExt0Ty
+                           : InnerExt1Ty,
+                       VF));
+        InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+            isUnsigned, ElementTy, MaxExtVecTy, CostKind);
+        InstructionCost InnerExtCost =
+            Ctx.TTI.getCastInstrCost(InnerExt0->getOpcode(), MulTy, MaxExtVecTy,
+                                     TTI::CastContextHint::None, CostKind);
+        InstructionCost MulCost =
+            Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+        InstructionCost ExtCost =
+            Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                     TTI::CastContextHint::None, CostKind);
+        if (RedCost.isValid() &&
+            RedCost < InnerExtCost * 2 + MulCost + ExtCost + BaseCost)
+          return RedCost;
+      }
+    }
+
+    // Match reduce(ext(A))
+    InstructionCost RedCost =
+        Ctx.TTI.getExtendedReductionCost(Opcode, isUnsigned, ElementTy, ExtTy,
+                                         RdxDesc.getFastMathFlags(), CostKind);
+    InstructionCost ExtCost =
+        Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                 TTI::CastContextHint::None, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
+      return RedCost;
+  }
+
+  // Try to match reduce.add(mul(...))
+  auto *Mul =
+      dyn_cast_if_present<VPWidenRecipe>(getVecOp()->getDefiningRecipe());
+  if (Mul && Mul->getOpcode() == Instruction::Mul &&
+      Opcode == Instruction::Add) {
+    // Match reduce.add(mul(ext(A), ext(B)))
+    auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+    auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+    auto *MulTy =
+        cast<VectorType>(ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+    InstructionCost MulCost =
+        Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+    if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+        InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+      Type *InnerExt0Ty =
+          InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+      Type *InnerExt1Ty =
+          InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+      auto *MaxInnerExtVecTy = cast<VectorType>(ToVectorTy(
+          InnerExt0Ty->getIntegerBitWidth() > InnerExt1Ty->getIntegerBitWidth()
+              ? InnerExt0Ty
+              : InnerExt1Ty,
+          VF));
+      bool isUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+          isUnsigned, ElementTy, MaxInnerExtVecTy, CostKind);
+      InstructionCost InnerExtCost = Ctx.TTI.getCastInstrCost(
+          InnerExt0->getOpcode(), MulTy, MaxInnerExtVecTy,
+          TTI::CastContextHint::None, CostKind);
+      if (RedCost.isValid() && RedCost < BaseCost + MulCost + 2 * InnerExtCost)
+        return RedCost;
+    }
+    // Match reduce.add(mul)
+    InstructionCost RedCost =
+        Ctx.TTI.getMulAccReductionCost(true, ElementTy, VectorTy, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + MulCost)
+      return RedCost;
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Normal cost = Reduction cost + BinOp cost
+  return BaseCost + Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
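
As a concrete instance of the reduce.add(mul(ext(A), ext(B))) pattern above, here is a hedged IR sketch (value names are illustrative; llvm.vector.reduce.add is the real intrinsic) of a vector body that computeCost now recognizes. On an ARM MVE target this whole sequence can lower to a single VMLADAV, which is why getMulAccReductionCost can come in below BaseCost + MulCost + 2 * InnerExtCost, in which case RedCost is returned instead.

; Illustrative vector body for reduce.add(mul(ext(A), ext(B))).
%wa = sext <16 x i8> %va to <16 x i32>
%wb = sext <16 x i8> %vb to <16 x i32>
%m  = mul nsw <16 x i32> %wa, %wb
%r  = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %m)
; If the target folds this to VMLADAV, the single mul-acc reduction cost
; beats the summed costs of the two sexts, the mul, and the add reduction.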
