Skip to content

Commit a60385c

Browse files
committed
[VPlan] Implement transformation for widen-cast/widen-mul + reduction to abstract recipe.
This patch introduces two new recipes. * VPExtendedReductionRecipe - cast + reduction. * VPMulAccumulateReductionRecipe - (cast) + mul + reduction. This patch also implements the transformation that matches the following patterns via VPlan and converts them to abstract recipes for better cost estimation. * VPExtendedReductionRecipe - reduce(cast(...)) * VPMulAccumulateReductionRecipe - reduce.add(mul(...)) - reduce.add(mul(ext(...), ext(...))) - reduce.add(ext(mul(ext(...), ext(...)))) The converted abstract recipes will be lowered to the concrete recipes (widen-cast + widen-mul + reduction) just before recipe execution. Split from #113903.
1 parent 59a73bd commit a60385c

12 files changed

+838
-78
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9563,10 +9563,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
95639563
"entry block must be set to a VPRegionBlock having a non-empty entry "
95649564
"VPBasicBlock");
95659565

9566-
for (ElementCount VF : Range)
9567-
Plan->addVF(VF);
9568-
Plan->setName("Initial VPlan");
9569-
95709566
// Update wide induction increments to use the same step as the corresponding
95719567
// wide induction. This enables detecting induction increments directly in
95729568
// VPlan and removes redundant splats.
@@ -9602,6 +9598,21 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
96029598
// Adjust the recipes for any inloop reductions.
96039599
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
96049600

9601+
// Transform recipes to abstract recipes if it is legal and beneficial and
9602+
// clamp the range for better cost estimation.
9603+
// TODO: Enable the following transform when the EVL versions of extended-reduction
9604+
// and mulacc-reduction are implemented.
9605+
if (!CM.foldTailWithEVL()) {
9606+
VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(), CM,
9607+
CM.CostKind);
9608+
VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
9609+
CostCtx, Range);
9610+
}
9611+
9612+
for (ElementCount VF : Range)
9613+
Plan->addVF(VF);
9614+
Plan->setName("Initial VPlan");
9615+
96059616
// Interleave memory: for each Interleave Group we marked earlier as relevant
96069617
// for this VPlan, replace the Recipes widening its memory instructions with a
96079618
// single VPInterleaveRecipe at its insertion point.

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 252 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
516516
case VPRecipeBase::VPInstructionSC:
517517
case VPRecipeBase::VPReductionEVLSC:
518518
case VPRecipeBase::VPReductionSC:
519+
case VPRecipeBase::VPMulAccumulateReductionSC:
520+
case VPRecipeBase::VPExtendedReductionSC:
519521
case VPRecipeBase::VPReplicateSC:
520522
case VPRecipeBase::VPScalarIVStepsSC:
521523
case VPRecipeBase::VPVectorPointerSC:
@@ -600,13 +602,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
600602
DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
601603
};
602604

605+
struct NonNegFlagsTy {
606+
char NonNeg : 1;
607+
NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
608+
};
609+
603610
private:
604611
struct ExactFlagsTy {
605612
char IsExact : 1;
606613
};
607-
struct NonNegFlagsTy {
608-
char NonNeg : 1;
609-
};
610614
struct FastMathFlagsTy {
611615
char AllowReassoc : 1;
612616
char NoNaNs : 1;
@@ -696,6 +700,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
696700
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
697701
DisjointFlags(DisjointFlags) {}
698702

703+
template <typename IterT>
704+
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
705+
NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
706+
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
707+
NonNegFlags(NonNegFlags) {}
708+
699709
protected:
700710
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
701711
GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
@@ -714,7 +724,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
714724
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
715725
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
716726
R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
717-
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
727+
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
728+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
729+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
718730
}
719731

720732
static inline bool classof(const VPUser *U) {
@@ -811,6 +823,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
811823

812824
FastMathFlags getFastMathFlags() const;
813825

826+
/// Returns true if the recipe has non-negative flag.
827+
bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
828+
829+
bool isNonNeg() const {
830+
assert(OpType == OperationType::NonNegOp &&
831+
"recipe doesn't have a NNEG flag");
832+
return NonNegFlags.NonNeg;
833+
}
834+
814835
bool hasNoUnsignedWrap() const {
815836
assert(OpType == OperationType::OverflowingBinOp &&
816837
"recipe doesn't have a NUW flag");
@@ -1243,10 +1264,21 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12431264
: VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I),
12441265
Opcode(I.getOpcode()) {}
12451266

1267+
template <typename IterT>
1268+
VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode,
1269+
iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
1270+
: VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
1271+
Opcode(Opcode) {}
1272+
12461273
public:
12471274
VPWidenRecipe(Instruction &I, ArrayRef<VPValue *> Operands)
12481275
: VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
12491276

1277+
template <typename IterT>
1278+
VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
1279+
bool NSW, DebugLoc DL)
1280+
: VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
1281+
12501282
~VPWidenRecipe() override = default;
12511283

12521284
VPWidenRecipe *clone() override {
@@ -1291,8 +1323,15 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12911323
"opcode of underlying cast doesn't match");
12921324
}
12931325

1294-
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
1295-
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(),
1326+
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1327+
DebugLoc DL = {})
1328+
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
1329+
Opcode(Opcode), ResultTy(ResultTy) {}
1330+
1331+
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1332+
bool IsNonNeg, DebugLoc DL = {})
1333+
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
1334+
DL),
12961335
Opcode(Opcode), ResultTy(ResultTy) {}
12971336

12981337
~VPWidenCastRecipe() override = default;
@@ -2325,6 +2364,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23252364
setUnderlyingValue(I);
23262365
}
23272366

2367+
/// For VPExtendedReductionRecipe.
2368+
/// Note that the debug location is from the extend.
2369+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2370+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2371+
bool IsOrdered, DebugLoc DL)
2372+
: VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2373+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2374+
if (CondOp)
2375+
addOperand(CondOp);
2376+
}
2377+
2378+
/// For VPMulAccumulateReductionRecipe.
2379+
/// Note that the NUW/NSW flags and the debug location are from the Mul.
2380+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2381+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2382+
bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2383+
: VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2384+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2385+
if (CondOp)
2386+
addOperand(CondOp);
2387+
}
2388+
23282389
public:
23292390
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
23302391
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2333,6 +2394,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23332394
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
23342395
IsOrdered, DL) {}
23352396

2397+
VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
2398+
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2399+
bool IsOrdered, DebugLoc DL = {})
2400+
: VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
2401+
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2402+
IsOrdered, DL) {}
2403+
23362404
~VPReductionRecipe() override = default;
23372405

23382406
VPReductionRecipe *clone() override {
@@ -2343,7 +2411,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23432411

23442412
static inline bool classof(const VPRecipeBase *R) {
23452413
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
2346-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
2414+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
2415+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
2416+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
23472417
}
23482418

23492419
static inline bool classof(const VPUser *U) {
@@ -2482,6 +2552,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
24822552
}
24832553
};
24842554

2555+
/// A recipe to represent inloop extended reduction operations, performing a
2556+
/// reduction on an extended vector operand into a scalar value, and adding the
2557+
/// result to a chain. This recipe is abstract and needs to be lowered to
2558+
/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
2559+
/// [Condition]}.
2560+
class VPExtendedReductionRecipe : public VPReductionRecipe {
2561+
/// Opcode of the extend that this recipe will be lowered to.
2562+
Instruction::CastOps ExtOp;
2563+
2564+
Type *ResultTy;
2565+
2566+
/// For cloning VPExtendedReductionRecipe.
2567+
VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
2568+
: VPReductionRecipe(
2569+
VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
2570+
{ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
2571+
ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2572+
ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
2573+
transferFlags(*ExtRed);
2574+
}
2575+
2576+
public:
2577+
VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2578+
: VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
2579+
{R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
2580+
R->isOrdered(), Ext->getDebugLoc()),
2581+
ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
2582+
// Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2583+
// the original recipe to prevent setting wrong flags.
2584+
transferFlags(*Ext);
2585+
}
2586+
2587+
~VPExtendedReductionRecipe() override = default;
2588+
2589+
VPExtendedReductionRecipe *clone() override {
2590+
auto *Copy = new VPExtendedReductionRecipe(this);
2591+
Copy->transferFlags(*this);
2592+
return Copy;
2593+
}
2594+
2595+
VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
2596+
2597+
void execute(VPTransformState &State) override {
2598+
llvm_unreachable("VPExtendedReductionRecipe should be transform to "
2599+
"VPExtendedRecipe + VPReductionRecipe before execution.");
2600+
};
2601+
2602+
/// Return the cost of VPExtendedReductionRecipe.
2603+
InstructionCost computeCost(ElementCount VF,
2604+
VPCostContext &Ctx) const override;
2605+
2606+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2607+
/// Print the recipe.
2608+
void print(raw_ostream &O, const Twine &Indent,
2609+
VPSlotTracker &SlotTracker) const override;
2610+
#endif
2611+
2612+
/// The scalar type after extending.
2613+
Type *getResultType() const { return ResultTy; }
2614+
2615+
/// Is the extend ZExt?
2616+
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
2617+
2618+
/// The opcode of extend recipe.
2619+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2620+
};
2621+
2622+
/// A recipe to represent inloop MulAccumulateReduction operations, performing a
2623+
/// reduction.add on the result of vector operands (might be extended)
2624+
/// multiplication into a scalar value, and adding the result to a chain. This
2625+
/// recipe is abstract and needs to be lowered to concrete recipes before
2626+
/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2627+
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2628+
/// Opcode of the extend recipe.
2629+
Instruction::CastOps ExtOp;
2630+
2631+
/// Non-neg flag of the extend recipe.
2632+
bool IsNonNeg = false;
2633+
2634+
Type *ResultTy;
2635+
2636+
/// For cloning VPMulAccumulateReductionRecipe.
2637+
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
2638+
: VPReductionRecipe(
2639+
VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
2640+
{MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
2641+
MulAcc->getCondOp(), MulAcc->isOrdered(),
2642+
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
2643+
MulAcc->getDebugLoc()),
2644+
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2645+
ResultTy(MulAcc->getResultType()) {}
2646+
2647+
public:
2648+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
2649+
VPWidenCastRecipe *Ext0,
2650+
VPWidenCastRecipe *Ext1, Type *ResultTy)
2651+
: VPReductionRecipe(
2652+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2653+
{R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
2654+
R->getCondOp(), R->isOrdered(),
2655+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2656+
R->getDebugLoc()),
2657+
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
2658+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2659+
Instruction::Add &&
2660+
"The reduction instruction in MulAccumulateteReductionRecipe must "
2661+
"be Add");
2662+
// Only set the non-negative flag if the original extend recipe carries one.
2663+
if (Ext0->hasNonNegFlag())
2664+
IsNonNeg = Ext0->isNonNeg();
2665+
}
2666+
2667+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
2668+
: VPReductionRecipe(
2669+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2670+
{R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
2671+
R->getCondOp(), R->isOrdered(),
2672+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2673+
R->getDebugLoc()),
2674+
ExtOp(Instruction::CastOps::CastOpsEnd) {
2675+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2676+
Instruction::Add &&
2677+
"The reduction instruction in MulAccumulateReductionRecipe must be "
2678+
"Add");
2679+
}
2680+
2681+
~VPMulAccumulateReductionRecipe() override = default;
2682+
2683+
VPMulAccumulateReductionRecipe *clone() override {
2684+
auto *Copy = new VPMulAccumulateReductionRecipe(this);
2685+
Copy->transferFlags(*this);
2686+
return Copy;
2687+
}
2688+
2689+
VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
2690+
2691+
void execute(VPTransformState &State) override {
2692+
llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
2693+
"VPWidenCastRecipe + "
2694+
"VPWidenRecipe + VPReductionRecipe before execution");
2695+
}
2696+
2697+
/// Return the cost of VPMulAccumulateReductionRecipe.
2698+
InstructionCost computeCost(ElementCount VF,
2699+
VPCostContext &Ctx) const override;
2700+
2701+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2702+
/// Print the recipe.
2703+
void print(raw_ostream &O, const Twine &Indent,
2704+
VPSlotTracker &SlotTracker) const override;
2705+
#endif
2706+
2707+
Type *getResultType() const {
2708+
assert(isExtended() && "Only support getResultType when this recipe "
2709+
"contains implicit extend.");
2710+
return ResultTy;
2711+
}
2712+
2713+
/// The VPValue of the vector value to be extended and reduced.
2714+
VPValue *getVecOp0() const { return getOperand(1); }
2715+
VPValue *getVecOp1() const { return getOperand(2); }
2716+
2717+
/// Return true if this MulAcc recipe contains extended operands.
2718+
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2719+
2720+
/// Return the opcode of the extends for the operands.
2721+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2722+
2723+
/// Return true if the operands are zero extended.
2724+
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
2725+
2726+
/// Return the non negative flag of the ext recipe.
2727+
bool isNonNeg() const { return IsNonNeg; }
2728+
};
2729+
24852730
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
24862731
/// copies of the original scalar type, one per lane, instead of producing a
24872732
/// single copy of widened type for all lanes. If the instruction is known to be

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
273273
// TODO: Use info from interleave group.
274274
return V->getUnderlyingValue()->getType();
275275
})
276+
.Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
277+
[](const auto *R) { return R->getResultType(); })
276278
.Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
277279
return R->getSCEV()->getType();
278280
})

0 commit comments

Comments
 (0)