Skip to content

Commit 22578ee

Browse files
committed
[VPlan] Add recipes for extended-reduction and mul-accumulate-reduction. NFC
This patch add two new recipes for extended-reduction and the mul-accumulate-reductions. Split from #113904.
1 parent d913ea3 commit 22578ee

File tree

4 files changed

+304
-5
lines changed

4 files changed

+304
-5
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 232 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
525525
case VPRecipeBase::VPInstructionSC:
526526
case VPRecipeBase::VPReductionEVLSC:
527527
case VPRecipeBase::VPReductionSC:
528+
case VPRecipeBase::VPMulAccumulateReductionSC:
529+
case VPRecipeBase::VPExtendedReductionSC:
528530
case VPRecipeBase::VPReplicateSC:
529531
case VPRecipeBase::VPScalarIVStepsSC:
530532
case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
609611
DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
610612
};
611613

614+
struct NonNegFlagsTy {
615+
char NonNeg : 1;
616+
NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
617+
};
618+
612619
private:
613620
struct ExactFlagsTy {
614621
char IsExact : 1;
615622
};
616-
struct NonNegFlagsTy {
617-
char NonNeg : 1;
618-
};
619623
struct FastMathFlagsTy {
620624
char AllowReassoc : 1;
621625
char NoNaNs : 1;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
709713
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
710714
DisjointFlags(DisjointFlags) {}
711715

716+
template <typename IterT>
717+
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
718+
NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
719+
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
720+
NonNegFlags(NonNegFlags) {}
721+
712722
protected:
713723
template <typename IterT>
714724
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
728738
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
729739
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
730740
R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
731-
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
741+
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
742+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
743+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
732744
}
733745

734746
static inline bool classof(const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
820832

821833
FastMathFlags getFastMathFlags() const;
822834

835+
/// Returns true if the recipe has non-negative flag.
836+
bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
837+
838+
bool isNonNeg() const {
839+
assert(OpType == OperationType::NonNegOp &&
840+
"recipe doesn't have a NNEG flag");
841+
return NonNegFlags.NonNeg;
842+
}
843+
823844
bool hasNoUnsignedWrap() const {
824845
assert(OpType == OperationType::OverflowingBinOp &&
825846
"recipe doesn't have a NUW flag");
@@ -2373,6 +2394,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23732394
setUnderlyingValue(I);
23742395
}
23752396

2397+
/// For VPExtendedReductionRecipe.
2398+
/// Note that the debug location is from the extend.
2399+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2400+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2401+
bool IsOrdered, DebugLoc DL)
2402+
: VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2403+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2404+
if (CondOp)
2405+
addOperand(CondOp);
2406+
}
2407+
2408+
/// For VPMulAccumulateReductionRecipe.
2409+
/// Note that the NUW/NSW flags and the debug location are from the Mul.
2410+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2411+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2412+
bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2413+
: VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2414+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2415+
if (CondOp)
2416+
addOperand(CondOp);
2417+
}
2418+
23762419
public:
23772420
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
23782421
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2381,6 +2424,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23812424
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
23822425
IsOrdered, DL) {}
23832426

2427+
VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
2428+
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2429+
bool IsOrdered, DebugLoc DL = {})
2430+
: VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
2431+
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2432+
IsOrdered, DL) {}
2433+
23842434
~VPReductionRecipe() override = default;
23852435

23862436
VPReductionRecipe *clone() override {
@@ -2391,7 +2441,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23912441

23922442
static inline bool classof(const VPRecipeBase *R) {
23932443
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
2394-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
2444+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
2445+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
2446+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
23952447
}
23962448

23972449
static inline bool classof(const VPUser *U) {
@@ -2471,6 +2523,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
24712523
}
24722524
};
24732525

2526+
/// A recipe to represent inloop extended reduction operations, performing a
2527+
/// reduction on a extended vector operand into a scalar value, and adding the
2528+
/// result to a chain. This recipe is abstract and needs to be lowered to
2529+
/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
2530+
/// [Condition]}.
2531+
class VPExtendedReductionRecipe : public VPReductionRecipe {
2532+
/// Opcode of the extend recipe will be lowered to.
2533+
Instruction::CastOps ExtOp;
2534+
2535+
Type *ResultTy;
2536+
2537+
/// For cloning VPExtendedReductionRecipe.
2538+
VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
2539+
: VPReductionRecipe(
2540+
VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
2541+
{ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
2542+
ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2543+
ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
2544+
transferFlags(*ExtRed);
2545+
}
2546+
2547+
public:
2548+
VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2549+
: VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
2550+
{R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
2551+
R->isOrdered(), Ext->getDebugLoc()),
2552+
ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
2553+
// Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2554+
// the original recipe to prevent setting wrong flags.
2555+
transferFlags(*Ext);
2556+
}
2557+
2558+
~VPExtendedReductionRecipe() override = default;
2559+
2560+
VPExtendedReductionRecipe *clone() override {
2561+
auto *Copy = new VPExtendedReductionRecipe(this);
2562+
Copy->transferFlags(*this);
2563+
return Copy;
2564+
}
2565+
2566+
VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
2567+
2568+
void execute(VPTransformState &State) override {
2569+
llvm_unreachable("VPExtendedReductionRecipe should be transform to "
2570+
"VPExtendedRecipe + VPReductionRecipe before execution.");
2571+
};
2572+
2573+
/// Return the cost of VPExtendedReductionRecipe.
2574+
InstructionCost computeCost(ElementCount VF,
2575+
VPCostContext &Ctx) const override;
2576+
2577+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2578+
/// Print the recipe.
2579+
void print(raw_ostream &O, const Twine &Indent,
2580+
VPSlotTracker &SlotTracker) const override;
2581+
#endif
2582+
2583+
/// The scalar type after extending.
2584+
Type *getResultType() const { return ResultTy; }
2585+
2586+
/// Is the extend ZExt?
2587+
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
2588+
2589+
/// The opcode of extend recipe.
2590+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2591+
};
2592+
2593+
/// A recipe to represent inloop MulAccumulateReduction operations, performing a
2594+
/// reduction.add on the result of vector operands (might be extended)
2595+
/// multiplication into a scalar value, and adding the result to a chain. This
2596+
/// recipe is abstract and needs to be lowered to concrete recipes before
2597+
/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2598+
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2599+
/// Opcode of the extend recipe.
2600+
Instruction::CastOps ExtOp;
2601+
2602+
/// Non-neg flag of the extend recipe.
2603+
bool IsNonNeg = false;
2604+
2605+
Type *ResultTy;
2606+
2607+
/// For cloning VPMulAccumulateReductionRecipe.
2608+
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
2609+
: VPReductionRecipe(
2610+
VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
2611+
{MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
2612+
MulAcc->getCondOp(), MulAcc->isOrdered(),
2613+
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
2614+
MulAcc->getDebugLoc()),
2615+
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2616+
ResultTy(MulAcc->getResultType()) {}
2617+
2618+
public:
2619+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
2620+
VPWidenCastRecipe *Ext0,
2621+
VPWidenCastRecipe *Ext1, Type *ResultTy)
2622+
: VPReductionRecipe(
2623+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2624+
{R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
2625+
R->getCondOp(), R->isOrdered(),
2626+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2627+
R->getDebugLoc()),
2628+
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
2629+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2630+
Instruction::Add &&
2631+
"The reduction instruction in MulAccumulateteReductionRecipe must "
2632+
"be Add");
2633+
// Only set the non-negative flag if the original recipe contains.
2634+
if (Ext0->hasNonNegFlag())
2635+
IsNonNeg = Ext0->isNonNeg();
2636+
}
2637+
2638+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
2639+
: VPReductionRecipe(
2640+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2641+
{R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
2642+
R->getCondOp(), R->isOrdered(),
2643+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2644+
R->getDebugLoc()),
2645+
ExtOp(Instruction::CastOps::CastOpsEnd) {
2646+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2647+
Instruction::Add &&
2648+
"The reduction instruction in MulAccumulateReductionRecipe must be "
2649+
"Add");
2650+
}
2651+
2652+
~VPMulAccumulateReductionRecipe() override = default;
2653+
2654+
VPMulAccumulateReductionRecipe *clone() override {
2655+
auto *Copy = new VPMulAccumulateReductionRecipe(this);
2656+
Copy->transferFlags(*this);
2657+
return Copy;
2658+
}
2659+
2660+
VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
2661+
2662+
void execute(VPTransformState &State) override {
2663+
llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
2664+
"VPWidenCastRecipe + "
2665+
"VPWidenRecipe + VPReductionRecipe before execution");
2666+
}
2667+
2668+
/// Return the cost of VPMulAccumulateReductionRecipe.
2669+
InstructionCost computeCost(ElementCount VF,
2670+
VPCostContext &Ctx) const override;
2671+
2672+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2673+
/// Print the recipe.
2674+
void print(raw_ostream &O, const Twine &Indent,
2675+
VPSlotTracker &SlotTracker) const override;
2676+
#endif
2677+
2678+
Type *getResultType() const {
2679+
assert(isExtended() && "Only support getResultType when this recipe "
2680+
"contains implicit extend.");
2681+
return ResultTy;
2682+
}
2683+
2684+
/// The VPValue of the vector value to be extended and reduced.
2685+
VPValue *getVecOp0() const { return getOperand(1); }
2686+
VPValue *getVecOp1() const { return getOperand(2); }
2687+
2688+
/// Return if this MulAcc recipe contains extended operands.
2689+
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2690+
2691+
/// Return the opcode of the extends for the operands.
2692+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2693+
2694+
/// Return if the operands are zero extended.
2695+
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
2696+
2697+
/// Return the non negative flag of the ext recipe.
2698+
bool isNonNeg() const { return IsNonNeg; }
2699+
};
2700+
24742701
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
24752702
/// copies of the original scalar type, one per lane, instead of producing a
24762703
/// single copy of widened type for all lanes. If the instruction is known to be

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
273273
// TODO: Use info from interleave group.
274274
return V->getUnderlyingValue()->getType();
275275
})
276+
.Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
277+
[](const auto *R) { return R->getResultType(); })
276278
.Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
277279
return R->getSCEV()->getType();
278280
})

0 commit comments

Comments
 (0)