
Commit 664c937

[VPlan] Implement VPExtendedReduction, VPMulAccumulateReductionRecipe and corresponding vplan transformations. (#137746)
This patch introduces two new recipes:

* VPExtendedReductionRecipe: cast + reduction.
* VPMulAccumulateReductionRecipe: (cast) + mul + reduction.

It also implements a VPlan transformation that matches the following patterns and converts them to these abstract recipes for better cost estimation:

* VPExtendedReductionRecipe
  - reduce(cast(...))
* VPMulAccumulateReductionRecipe
  - reduce.add(mul(...))
  - reduce.add(mul(ext(...), ext(...)))
  - reduce.add(ext(mul(ext(...), ext(...))))

The abstract recipes are lowered to concrete recipes (widen-cast + widen-mul + reduction) just before recipe execution. Note that this patch still relies on the legacy cost model to calculate the cost of these patterns; VPlan-based cost decisions will be enabled in #113903. Split from #113903.
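For illustration only (this example is not part of the commit): a classic i8 dot product is the kind of loop the mul-accumulate pattern targets, since the multiply of two sign-extended i8 loads feeds a reduce.add:

    // Hypothetical source loop matching reduce.add(mul(ext(...), ext(...))).
    // On targets with dot-product style instructions, costing the whole
    // chain as one abstract recipe is cheaper than costing the ext, mul and
    // reduce separately.
    int dot(const signed char *a, const signed char *b, int n) {
      int sum = 0;
      for (int i = 0; i < n; ++i)
        sum += (int)a[i] * (int)b[i]; // sext i8 -> i32, mul, reduce.add
      return sum;
    }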
1 parent 13c484c commit 664c937

12 files changed: +809 −71 lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 15 additions & 4 deletions
@@ -9568,10 +9568,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
          "entry block must be set to a VPRegionBlock having a non-empty entry "
          "VPBasicBlock");
 
-  for (ElementCount VF : Range)
-    Plan->addVF(VF);
-  Plan->setName("Initial VPlan");
-
   // Update wide induction increments to use the same step as the corresponding
   // wide induction. This enables detecting induction increments directly in
   // VPlan and removes redundant splats.
@@ -9601,6 +9597,21 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
   // Adjust the recipes for any inloop reductions.
   adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
 
+  // Transform recipes to abstract recipes if it is legal and beneficial and
+  // clamp the range for better cost estimation.
+  // TODO: Enable following transform when the EVL-version of extended-reduction
+  // and mulacc-reduction are implemented.
+  if (!CM.foldTailWithEVL()) {
+    VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(), CM,
+                          CM.CostKind);
+    VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
+                             CostCtx, Range);
+  }
+
+  for (ElementCount VF : Range)
+    Plan->addVF(VF);
+  Plan->setName("Initial VPlan");
+
   // Interleave memory: for each Interleave Group we marked earlier as relevant
   // for this VPlan, replace the Recipes widening its memory instructions with a
   // single VPInterleaveRecipe at its insertion point.
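The addVF loop and setName call move below the new transform because, per the comment above, convertToAbstractRecipes may clamp the VF range while deciding profitability, so the VFs must be recorded afterwards. As a rough sketch (not code from this commit), a transform can clamp the range through the existing LoopVectorizationPlanner::getDecisionAndClampRange helper; the profitability predicate named below is hypothetical:

    // Hypothetical sketch: keep only the VFs for which forming the abstract
    // recipe is deemed beneficial; Range is shrunk as a side effect.
    bool CreateAbstract = LoopVectorizationPlanner::getDecisionAndClampRange(
        [&](ElementCount VF) {
          return isAbstractRecipeBeneficial(VF); // hypothetical cost check
        },
        Range);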

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 251 additions & 7 deletions
@@ -517,6 +517,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccumulateReductionSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -601,13 +603,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
   };
 
+  struct NonNegFlagsTy {
+    char NonNeg : 1;
+    NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
+  };
+
 private:
   struct ExactFlagsTy {
     char IsExact : 1;
   };
-  struct NonNegFlagsTy {
-    char NonNeg : 1;
-  };
   struct FastMathFlagsTy {
     char AllowReassoc : 1;
     char NoNaNs : 1;
@@ -697,6 +701,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
         DisjointFlags(DisjointFlags) {}
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
+        NonNegFlags(NonNegFlags) {}
+
 protected:
   VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
                       GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
@@ -715,7 +725,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
            R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
            R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -812,6 +824,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 
   FastMathFlags getFastMathFlags() const;
 
+  /// Returns true if the recipe has a non-negative flag.
+  bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
+
+  bool isNonNeg() const {
+    assert(OpType == OperationType::NonNegOp &&
+           "recipe doesn't have a NNEG flag");
+    return NonNegFlags.NonNeg;
+  }
+
   bool hasNoUnsignedWrap() const {
     assert(OpType == OperationType::OverflowingBinOp &&
            "recipe doesn't have a NUW flag");
@@ -1294,10 +1315,19 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I),
        Opcode(I.getOpcode()) {}
 
+  VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode,
+                ArrayRef<VPValue *> Operands, bool NUW, bool NSW, DebugLoc DL)
+      : VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
+        Opcode(Opcode) {}
+
 public:
   VPWidenRecipe(Instruction &I, ArrayRef<VPValue *> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
+  VPWidenRecipe(unsigned Opcode, ArrayRef<VPValue *> Operands, bool NUW,
+                bool NSW, DebugLoc DL)
+      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
+
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -1342,8 +1372,15 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
            "opcode of underlying cast doesn't match");
   }
 
-  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(),
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+                    DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
+        Opcode(Opcode), ResultTy(ResultTy) {}
+
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+                    bool IsNonNeg, DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
+                            DL),
         Opcode(Opcode), ResultTy(ResultTy) {}
 
   ~VPWidenCastRecipe() override = default;
@@ -2394,6 +2431,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
     setUnderlyingValue(I);
   }
 
+  /// For VPExtendedReductionRecipe.
+  /// Note that the debug location is from the extend.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
+  /// For VPMulAccumulateReductionRecipe.
+  /// Note that the NUW/NSW flags and the debug location are from the Mul.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
 public:
   VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2402,6 +2461,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
 
+  VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
+                    VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL = {})
+      : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
+                          ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                          IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
@@ -2412,7 +2478,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
 
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2551,6 +2619,182 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent inloop extended reduction operations, performing a
+/// reduction on an extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is abstract and needs to be lowered to
+/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
+/// [Condition]}.
+class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend for VecOp.
+  Instruction::CastOps ExtOp;
+
+  /// The scalar type after extending.
+  Type *ResultTy;
+
+  /// For cloning VPExtendedReductionRecipe.
+  VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
+      : VPReductionRecipe(
+            VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
+            {ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
+            ExtRed->isOrdered(), ExtRed->getDebugLoc()),
+        ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
+    transferFlags(*ExtRed);
+    setUnderlyingValue(ExtRed->getUnderlyingValue());
+  }
+
+public:
+  VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
+      : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
+                          {R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
+                          R->isOrdered(), Ext->getDebugLoc()),
+        ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
+    assert((ExtOp == Instruction::CastOps::ZExt ||
+            ExtOp == Instruction::CastOps::SExt) &&
+           "VPExtendedReductionRecipe only supports zext and sext.");
+
+    transferFlags(*Ext);
+    setUnderlyingValue(R->getUnderlyingValue());
+  }
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    return new VPExtendedReductionRecipe(this);
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transformed to "
+                     "VPWidenCastRecipe + VPReductionRecipe before execution.");
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The scalar type after extending.
+  Type *getResultType() const { return ResultTy; }
+
+  /// Is the extend ZExt?
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
+  /// Get the opcode of the extend for VecOp.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+};
+
+/// A recipe to represent inloop MulAccumulateReduction operations, multiplying
+/// the vector operands (which may be extended), performing a reduction.add on
+/// the result, and adding the scalar result to a chain. This recipe is
+/// abstract and needs to be lowered to concrete recipes before codegen. The
+/// operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend for VecOp1 and VecOp2.
+  Instruction::CastOps ExtOp;
+
+  /// Non-neg flag of the extend recipe.
+  bool IsNonNeg = false;
+
+  /// The scalar type after extending.
+  Type *ResultTy = nullptr;
+
+  /// For cloning VPMulAccumulateReductionRecipe.
+  VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
+            {MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
+            MulAcc->getCondOp(), MulAcc->isOrdered(),
+            WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
+            MulAcc->getDebugLoc()),
+        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        ResultTy(MulAcc->getResultType()) {
+    transferFlags(*MulAcc);
+    setUnderlyingValue(MulAcc->getUnderlyingValue());
+  }
+
+public:
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
+                                 VPWidenCastRecipe *Ext0,
+                                 VPWidenCastRecipe *Ext1, Type *ResultTy)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateReductionRecipe must "
+           "be Add");
+    assert((ExtOp == Instruction::CastOps::ZExt ||
+            ExtOp == Instruction::CastOps::SExt) &&
+           "VPMulAccumulateReductionRecipe only supports zext and sext.");
+    setUnderlyingValue(R->getUnderlyingValue());
+    // Only set the non-negative flag if the original extend recipe carries it.
+    if (Ext0->hasNonNegFlag())
+      IsNonNeg = Ext0->isNonNeg();
+  }
+
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
+                                 Type *ResultTy)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateReductionRecipe must be "
+           "Add");
+    setUnderlyingValue(R->getUnderlyingValue());
+  }
+
+  ~VPMulAccumulateReductionRecipe() override = default;
+
+  VPMulAccumulateReductionRecipe *clone() override {
+    return new VPMulAccumulateReductionRecipe(this);
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccumulateReductionRecipe should be transformed to "
+                     "VPWidenCastRecipe + VPWidenRecipe + VPReductionRecipe "
+                     "before execution");
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  Type *getResultType() const { return ResultTy; }
+
+  /// The first vector value to be extended and reduced.
+  VPValue *getVecOp0() const { return getOperand(1); }
+
+  /// The second vector value to be extended and reduced.
+  VPValue *getVecOp1() const { return getOperand(2); }
+
+  /// Return true if this recipe contains extended operands.
+  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+
+  /// Return the opcode of the extends for the operands.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+
+  /// Return true if the operands are zero-extended.
+  bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+
+  /// Return true if the operand extends have the non-negative flag.
+  bool isNonNeg() const { return IsNonNeg; }
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
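A rough sketch of the lowering direction (the actual lowering lives in VPlanTransforms, is not shown in this page's diff, and may differ in detail; insertion of the new recipes into the plan is omitted): an extended VPMulAccumulateReductionRecipe expands back into the concrete recipes declared above.

    // Hypothetical sketch: expand the abstract recipe into concrete ones,
    // using only the constructors added by this patch.
    VPValue *Op0 = MulAcc->getVecOp0(), *Op1 = MulAcc->getVecOp1();
    if (MulAcc->isExtended()) {
      // Re-create the widen-casts that were folded into the abstract recipe.
      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), Op0,
                                  MulAcc->getResultType());
      Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), Op1,
                                  MulAcc->getResultType());
    }
    // Widen-mul carrying the NUW/NSW flags stored on the abstract recipe.
    auto *Mul =
        new VPWidenRecipe(Instruction::Mul, {Op0, Op1},
                          MulAcc->hasNoUnsignedWrap(),
                          MulAcc->hasNoSignedWrap(), MulAcc->getDebugLoc());
    // Plain in-loop reduction chained onto the original ChainOp; FMFs are
    // irrelevant for an integer reduce.add.
    auto *Red = new VPReductionRecipe(
        MulAcc->getRecurrenceKind(), FastMathFlags(), MulAcc->getChainOp(),
        Mul, MulAcc->getCondOp(), MulAcc->isOrdered(), MulAcc->getDebugLoc());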

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 0 deletions
@@ -293,6 +293,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             // TODO: Use info from interleave group.
             return V->getUnderlyingValue()->getType();
           })
+          .Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
+              [](const auto *R) { return R->getResultType(); })
           .Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
             return R->getSCEV()->getType();
           })
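With this, type inference treats the abstract recipes as producing their post-extension type. A minimal usage sketch (assuming a VPTypeAnalysis constructed from the canonical IV type, with CanonicalIVTy and ExtRed in scope):

    // The inferred scalar type of the abstract reduction is the type after
    // extending, i.e. getResultType(), not the narrow pre-extend type.
    VPTypeAnalysis TypeInfo(CanonicalIVTy);
    Type *RdxTy = TypeInfo.inferScalarType(ExtRed);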
