Skip to content

[LoopVectorizer] Bundle partial reductions inside VPMulAccumulateReductionRecipe #136173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ class TargetTransformInfo {
/// Get the kind of extension that an instruction represents.
LLVM_ABI static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction *I);
static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction::CastOps ExtOpcode);

/// Construct a TTI object using a type implementing the \c Concept
/// API below.
Expand Down
19 changes: 15 additions & 4 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -995,13 +995,24 @@ InstructionCost TargetTransformInfo::getShuffleCost(

TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
if (isa<SExtInst>(I))
return PR_SignExtend;
if (isa<ZExtInst>(I))
return PR_ZeroExtend;
if (auto *Cast = dyn_cast<CastInst>(I))
return getPartialReductionExtendKind(Cast->getOpcode());
return PR_None;
}

TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(
Instruction::CastOps ExtOpcode) {
switch (ExtOpcode) {
case Instruction::CastOps::ZExt:
return PR_ZeroExtend;
case Instruction::CastOps::SExt:
return PR_SignExtend;
default:
llvm_unreachable("Unexpected cast opcode");
}
}

TTI::CastContextHint
TargetTransformInfo::getCastContextHint(const Instruction *I) {
if (!I)
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8639,9 +8639,6 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
"Expected an ADD or SUB operation for predicated partial "
"reductions (because the neutral element in the mask is zero)!");
Cond = getBlockInMask(Builder.getInsertBlock());
VPValue *Zero =
Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
BinOp = Builder.createSelect(Cond, BinOp, Zero, Reduction->getDebugLoc());
}
return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
ScaleFactor, Reduction);
Expand Down
23 changes: 19 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried downloading this patch and applying to the HEAD of LLVM and patch said this diff had already been applied. Does the PR need rebasing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah perhaps this is my mistake. You did say it depends upon #113903. :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's the case :). Let me know if you have any issues applying it after applying 113903 too.

Original file line number Diff line number Diff line change
Expand Up @@ -2470,7 +2470,8 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC ||
R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
}

static inline bool classof(const VPUser *U) {
Expand Down Expand Up @@ -2559,6 +2560,9 @@ class VPPartialReductionRecipe : public VPReductionRecipe {
/// Get the factor that the VF of this recipe's output should be scaled by.
unsigned getVFScaleFactor() const { return VFScaleFactor; }

/// Get the binary op this reduction is applied to.
VPValue *getBinOp() const { return getOperand(1); }

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,
Expand Down Expand Up @@ -2694,6 +2698,10 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
/// The scalar type after extending.
Type *ResultTy = nullptr;

/// The scaling factor, relative to the VF, that this recipe's output is
/// divided by
unsigned VFScaleFactor = 1;

/// For cloning VPMulAccumulateReductionRecipe.
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
: VPReductionRecipe(
Expand All @@ -2703,22 +2711,25 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
MulAcc->getDebugLoc()),
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
ResultTy(MulAcc->getResultType()) {
ResultTy(MulAcc->getResultType()),
VFScaleFactor(MulAcc->getVFScaleFactor()) {
transferFlags(*MulAcc);
setUnderlyingValue(MulAcc->getUnderlyingValue());
}

public:
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1, Type *ResultTy)
VPWidenCastRecipe *Ext1, Type *ResultTy,
unsigned ScaleFactor = 1)
: VPReductionRecipe(
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
{R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
R->getCondOp(), R->isOrdered(),
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
R->getDebugLoc()),
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy),
VFScaleFactor(ScaleFactor) {
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
Instruction::Add &&
"The reduction instruction in MulAccumulateteReductionRecipe must "
Expand Down Expand Up @@ -2791,6 +2802,10 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {

/// Return true if the operand extends have the non-negative flag.
bool isNonNeg() const { return IsNonNeg; }

/// Return the scaling factor that the VF is divided by to form the recipe's
/// output
unsigned getVFScaleFactor() const { return VFScaleFactor; }
};

/// VPReplicateRecipe replicates a given instruction producing multiple scalar
Expand Down
28 changes: 20 additions & 8 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
case VPWidenIntrinsicSC:
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
case VPBlendSC:
case VPPartialReductionSC:
case VPReductionEVLSC:
case VPReductionSC:
case VPExtendedReductionSC:
Expand Down Expand Up @@ -295,14 +296,9 @@ InstructionCost
VPPartialReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
std::optional<unsigned> Opcode = std::nullopt;
VPValue *BinOp = getOperand(1);
VPValue *BinOp = getBinOp();

// If the partial reduction is predicated, a select will be operand 0 rather
// than the binary op
using namespace llvm::VPlanPatternMatch;
if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
BinOp = BinOp->getDefiningRecipe()->getOperand(1);

// If BinOp is a negation, use the side effect of match to assign the actual
// binary operation to BinOp
match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
Expand Down Expand Up @@ -345,12 +341,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
assert(getOpcode() == Instruction::Add &&
"Unhandled partial reduction opcode");

Value *BinOpVal = State.get(getOperand(1));
Value *PhiVal = State.get(getOperand(0));
Value *BinOpVal = State.get(getBinOp());
Value *PhiVal = State.get(getChainOp());
assert(PhiVal && BinOpVal && "Phi and Mul must be set");

Type *RetTy = PhiVal->getType();

/// Mask the bin op output.
if (VPValue *Cond = getCondOp()) {
Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
BinOpVal = Builder.CreateSelect(State.get(Cond), BinOpVal, Zero);
}

CallInst *V = Builder.CreateIntrinsic(
RetTy, Intrinsic::experimental_vector_partial_reduce_add,
{PhiVal, BinOpVal}, nullptr, "partial.reduce");
Expand Down Expand Up @@ -2570,6 +2572,14 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
InstructionCost
VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
if (getVFScaleFactor() > 1) {
return Ctx.TTI.getPartialReductionCost(
Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
TTI::getPartialReductionExtendKind(getExtOpcode()),
TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
}

Type *RedTy = Ctx.Types.inferScalarType(this);
auto *SrcVecTy =
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
Expand Down Expand Up @@ -2648,6 +2658,8 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " + ";
if (getVFScaleFactor() > 1)
O << "partial.";
O << "reduce."
<< Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
Expand Down
45 changes: 41 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2581,9 +2581,15 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
MulAcc->hasNoSignedWrap(), MulAcc->getDebugLoc());
Mul->insertBefore(MulAcc);

auto *Red = new VPReductionRecipe(
MulAcc->getRecurrenceKind(), FastMathFlags(), MulAcc->getChainOp(), Mul,
MulAcc->getCondOp(), MulAcc->isOrdered(), MulAcc->getDebugLoc());
// Generate VPReductionRecipe.
VPReductionRecipe *Red = nullptr;
if (unsigned ScaleFactor = MulAcc->getVFScaleFactor(); ScaleFactor > 1)
Red = new VPPartialReductionRecipe(Instruction::Add, MulAcc->getChainOp(),
Mul, MulAcc->getCondOp(), ScaleFactor);
else
Red = new VPReductionRecipe(MulAcc->getRecurrenceKind(), FastMathFlags(),
MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
MulAcc->isOrdered(), MulAcc->getDebugLoc());
Red->insertBefore(MulAcc);

MulAcc->replaceAllUsesWith(Red);
Expand Down Expand Up @@ -2911,12 +2917,43 @@ static void tryToCreateAbstractReductionRecipe(VPReductionRecipe *Red,
Red->replaceAllUsesWith(AbstractR);
}

/// This function tries to create an abstract recipe from a partial reduction to
/// hide its mul and extends from cost estimation.
static void
tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be given the Range & and for that range to be clamped if it doesn't match or if the cost is higher than the individual operations (similar to what happens in tryToCreateAbstractReductionRecipe) ?

(note that the cost part is still missing)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point we've already created the partial reduction and clamped the range so I don't think we need to do any costing (like tryToMatchAndCreateMulAccumulateReduction does with getMulAccReductionCost) since we already know it's worthwhile (see getScaledReductions in LoopVectorize.cpp). This part of the code just puts the partial reduction inside the abstract recipe, which shouldn't need to consider any costing.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way I read the code is that at the point of getting to this point in the code, it has recognised a reduction so there is a VP[Partial]ReductionRecipe. It then tries to analyse whether that recipe can be transformed into a VPMulAccumulateReductionRecipe. For VPReductionRecipe it will clamp the range to all the VFs that can be turned into a VPMulAccumulateReductionRecipe, but for VPPartialReductionRecipe it doesn't do that. I don't see why for partial reductions we'd do something different.

In fact, why wouldn't the tryToMatchAndCreateMulAccumulateReduction code be sufficient here? Now that you've made VPPartialReductionRecipe a subclass of VPReductionRecipe, I'd expect that code to function roughly the same.

if (PRed->getOpcode() != Instruction::Add)
return;

using namespace llvm::VPlanPatternMatch;
auto *BinOp = PRed->getBinOp();
if (!match(BinOp,
m_Mul(m_ZExtOrSExt(m_VPValue()), m_ZExtOrSExt(m_VPValue()))))
return;

auto *BinOpR = cast<VPWidenRecipe>(BinOp->getDefiningRecipe());
VPWidenCastRecipe *Ext0R = dyn_cast<VPWidenCastRecipe>(BinOpR->getOperand(0));
VPWidenCastRecipe *Ext1R = dyn_cast<VPWidenCastRecipe>(BinOpR->getOperand(1));

// TODO: Make work with extends of different signedness
if (Ext0R->hasMoreThanOneUniqueUser() || Ext1R->hasMoreThanOneUniqueUser() ||
Ext0R->getOpcode() != Ext1R->getOpcode())
return;

auto *AbstractR = new VPMulAccumulateReductionRecipe(
PRed, BinOpR, Ext0R, Ext1R, Ext0R->getResultType(),
PRed->getVFScaleFactor());
AbstractR->insertBefore(PRed);
PRed->replaceAllUsesWith(AbstractR);
}

void VPlanTransforms::convertToAbstractRecipes(VPlan &Plan, VPCostContext &Ctx,
VFRange &Range) {
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
vp_depth_first_deep(Plan.getVectorLoopRegion()))) {
for (VPRecipeBase &R : *VPBB) {
if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
if (auto *PRed = dyn_cast<VPPartialReductionRecipe>(&R))
tryToCreateAbstractPartialReductionRecipe(PRed);
else if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
tryToCreateAbstractReductionRecipe(Red, Ctx, Range);
}
}
Expand Down
Loading
Loading