Skip to content

Commit e291008

Browse files
committed
[VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI
VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter, which allows it inherit from VPRecipeWithIRFlags. This also allows us to remove createReduction in LoopUtils since it now only has one user in VPInstruction::ComputeReductionResult. The motiviation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.
1 parent 297f6d9 commit e291008

File tree

5 files changed

+60
-74
lines changed

5 files changed

+60
-74
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
411411
RecurKind RdxKind);
412412
/// Overloaded function to generate vector-predication intrinsics for
413413
/// reduction.
414-
Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
415-
const RecurrenceDescriptor &Desc);
414+
Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
415+
FastMathFlags FMFs);
416416

417417
/// Create a reduction of the given vector \p Src for a reduction of the
418418
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -428,14 +428,12 @@ Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
428428
const RecurrenceDescriptor &Desc);
429429

430430
/// Create an ordered reduction intrinsic using the given recurrence
431-
/// descriptor \p Desc.
432-
Value *createOrderedReduction(IRBuilderBase &B,
433-
const RecurrenceDescriptor &Desc, Value *Src,
431+
/// kind \p Kind.
432+
Value *createOrderedReduction(IRBuilderBase &B, RecurKind Kind, Value *Src,
434433
Value *Start);
435434
/// Overloaded function to generate vector-predication intrinsics for ordered
436435
/// reduction.
437-
Value *createOrderedReduction(VectorBuilder &VB,
438-
const RecurrenceDescriptor &Desc, Value *Src,
436+
Value *createOrderedReduction(VectorBuilder &VB, RecurKind Kind, Value *Src,
439437
Value *Start);
440438

441439
/// Get the intersection (logical and) of all of the potential IR flags

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,36 +1333,31 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13331333
}
13341334

13351335
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
1336-
const RecurrenceDescriptor &Desc) {
1337-
RecurKind Kind = Desc.getRecurrenceKind();
1336+
RecurKind Kind, FastMathFlags FMFs) {
13381337
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
13391338
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
13401339
"AnyOf or FindLastIV reductions are not supported.");
13411340
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
13421341
auto *SrcTy = cast<VectorType>(Src->getType());
13431342
Type *SrcEltTy = SrcTy->getElementType();
1344-
Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
1343+
Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
13451344
Value *Ops[] = {Iden, Src};
13461345
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
13471346
}
13481347

1349-
Value *llvm::createOrderedReduction(IRBuilderBase &B,
1350-
const RecurrenceDescriptor &Desc,
1348+
Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
13511349
Value *Src, Value *Start) {
1352-
assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
1353-
Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
1350+
assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
13541351
"Unexpected reduction kind");
13551352
assert(Src->getType()->isVectorTy() && "Expected a vector type");
13561353
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
13571354

13581355
return B.CreateFAddReduce(Start, Src);
13591356
}
13601357

1361-
Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
1362-
const RecurrenceDescriptor &Desc,
1358+
Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
13631359
Value *Src, Value *Start) {
1364-
assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
1365-
Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
1360+
assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
13661361
"Unexpected reduction kind");
13671362
assert(Src->getType()->isVectorTy() && "Expected a vector type");
13681363
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,22 +2239,21 @@ class VPInterleaveRecipe : public VPRecipeBase {
22392239
/// a vector operand into a scalar value, and adding the result to a chain.
22402240
/// The Operands are {ChainOp, VecOp, [Condition]}.
22412241
class VPReductionRecipe : public VPRecipeWithIRFlags {
2242-
/// The recurrence decriptor for the reduction in question.
2243-
const RecurrenceDescriptor &RdxDesc;
2242+
/// The recurrence kind for the reduction in question.
2243+
RecurKind RdxKind;
22442244
bool IsOrdered;
22452245
/// Whether the reduction is conditional.
22462246
bool IsConditional = false;
22472247

22482248
protected:
2249-
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2250-
Instruction *I, ArrayRef<VPValue *> Operands,
2251-
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2252-
: VPRecipeWithIRFlags(SC, Operands,
2253-
isa_and_nonnull<FPMathOperator>(I)
2254-
? R.getFastMathFlags()
2255-
: FastMathFlags(),
2256-
DL),
2257-
RdxDesc(R), IsOrdered(IsOrdered) {
2249+
VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
2250+
FastMathFlags FMFs, Instruction *I,
2251+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2252+
bool IsOrdered, DebugLoc DL)
2253+
: VPRecipeWithIRFlags(
2254+
SC, Operands,
2255+
isa_and_nonnull<FPMathOperator>(I) ? FMFs : FastMathFlags(), DL),
2256+
RdxKind(RdxKind), IsOrdered(IsOrdered) {
22582257
if (CondOp) {
22592258
IsConditional = true;
22602259
addOperand(CondOp);
@@ -2263,19 +2262,25 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
22632262
}
22642263

22652264
public:
2266-
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
2265+
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
22672266
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
22682267
bool IsOrdered, DebugLoc DL = {})
2269-
: VPReductionRecipe(VPDef::VPReductionSC, R, I,
2268+
: VPReductionRecipe(VPRecipeBase::VPReductionSC, RdxKind, FMFs, I,
22702269
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
22712270
IsOrdered, DL) {}
22722271

2272+
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
2273+
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2274+
bool IsOrdered, DebugLoc DL = {})
2275+
: VPReductionRecipe(R.getRecurrenceKind(), R.getFastMathFlags(), I,
2276+
ChainOp, VecOp, CondOp, IsOrdered, DL) {}
2277+
22732278
~VPReductionRecipe() override = default;
22742279

22752280
VPReductionRecipe *clone() override {
2276-
return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
2277-
getVecOp(), getCondOp(), IsOrdered,
2278-
getDebugLoc());
2281+
return new VPReductionRecipe(RdxKind, getFastMathFlags(),
2282+
getUnderlyingInstr(), getChainOp(), getVecOp(),
2283+
getCondOp(), IsOrdered, getDebugLoc());
22792284
}
22802285

22812286
static inline bool classof(const VPRecipeBase *R) {
@@ -2301,9 +2306,11 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23012306
VPSlotTracker &SlotTracker) const override;
23022307
#endif
23032308

2304-
/// Return the recurrence decriptor for the in-loop reduction.
2305-
const RecurrenceDescriptor &getRecurrenceDescriptor() const {
2306-
return RdxDesc;
2309+
/// Return the recurrence kind for the in-loop reduction.
2310+
RecurKind getRecurrenceKind() const { return RdxKind; }
2311+
/// Return the opcode for the recurrence for the in-loop reduction.
2312+
unsigned getOpcode() const {
2313+
return RecurrenceDescriptor::getOpcode(RdxKind);
23072314
}
23082315
/// Return true if the in-loop reduction is ordered.
23092316
bool isOrdered() const { return IsOrdered; };
@@ -2328,7 +2335,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
23282335
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp,
23292336
DebugLoc DL = {})
23302337
: VPReductionRecipe(
2331-
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
2338+
VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
2339+
R.getFastMathFlags(),
23322340
cast_or_null<Instruction>(R.getUnderlyingValue()),
23332341
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
23342342
R.isOrdered(), DL) {}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,7 +2285,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
22852285
void VPReductionRecipe::execute(VPTransformState &State) {
22862286
assert(!State.Lane && "Reduction being replicated.");
22872287
Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
2288-
RecurKind Kind = RdxDesc.getRecurrenceKind();
2288+
RecurKind Kind = getRecurrenceKind();
22892289
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
22902290
"In-loop AnyOf reductions aren't currently supported");
22912291
// Propagate the fast-math flags carried by the underlying instruction.
@@ -2298,8 +2298,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22982298
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
22992299
Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
23002300

2301-
Value *Start =
2302-
getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
2301+
Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
23032302
if (State.VF.isVector())
23042303
Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
23052304

@@ -2311,21 +2310,20 @@ void VPReductionRecipe::execute(VPTransformState &State) {
23112310
if (IsOrdered) {
23122311
if (State.VF.isVector())
23132312
NewRed =
2314-
createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
2313+
createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
23152314
else
2316-
NewRed = State.Builder.CreateBinOp(
2317-
(Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
2315+
NewRed = State.Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(),
2316+
PrevInChain, NewVecOp);
23182317
PrevInChain = NewRed;
23192318
NextInChain = NewRed;
23202319
} else {
23212320
PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
23222321
NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
23232322
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
2324-
NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
2325-
NewRed, PrevInChain);
2323+
NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
23262324
else
23272325
NextInChain = State.Builder.CreateBinOp(
2328-
(Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
2326+
(Instruction::BinaryOps)getOpcode(), NewRed, PrevInChain);
23292327
}
23302328
State.set(this, NextInChain, /*IsScalar*/ true);
23312329
}
@@ -2336,10 +2334,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23362334
auto &Builder = State.Builder;
23372335
// Propagate the fast-math flags carried by the underlying instruction.
23382336
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
2339-
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
23402337
Builder.setFastMathFlags(getFastMathFlags());
23412338

2342-
RecurKind Kind = RdxDesc.getRecurrenceKind();
2339+
RecurKind Kind = getRecurrenceKind();
23432340
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
23442341
Value *VecOp = State.get(getVecOp());
23452342
Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2356,25 +2353,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23562353

23572354
Value *NewRed;
23582355
if (isOrdered()) {
2359-
NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
2356+
NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
23602357
} else {
2361-
NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
2358+
NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
23622359
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
23632360
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
23642361
else
2365-
NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
2366-
NewRed, Prev);
2362+
NewRed = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), NewRed,
2363+
Prev);
23672364
}
23682365
State.set(this, NewRed, /*IsScalar*/ true);
23692366
}
23702367

23712368
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23722369
VPCostContext &Ctx) const {
2373-
RecurKind RdxKind = RdxDesc.getRecurrenceKind();
2370+
RecurKind RdxKind = getRecurrenceKind();
23742371
Type *ElementTy = Ctx.Types.inferScalarType(this);
23752372
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
2376-
unsigned Opcode = RdxDesc.getOpcode();
2377-
FastMathFlags FMFs = getFastMathFlags();
23782373

23792374
// TODO: Support any-of and in-loop reductions.
23802375
assert(
@@ -2386,20 +2381,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23862381
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
23872382
"In-loop reduction not implemented in VPlan-based cost model currently.");
23882383

2389-
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
2390-
"Inferred type and recurrence type mismatch.");
2391-
23922384
// Cost = Reduction cost + BinOp cost
23932385
InstructionCost Cost =
2394-
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
2386+
Ctx.TTI.getArithmeticInstrCost(getOpcode(), ElementTy, Ctx.CostKind);
23952387
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
23962388
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2397-
return Cost +
2398-
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
2389+
return Cost + Ctx.TTI.getMinMaxReductionCost(
2390+
Id, VectorTy, getFastMathFlags(), Ctx.CostKind);
23992391
}
24002392

2401-
return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2402-
Ctx.CostKind);
2393+
return Cost + Ctx.TTI.getArithmeticReductionCost(
2394+
getOpcode(), VectorTy, getFastMathFlags(), Ctx.CostKind);
24032395
}
24042396

24052397
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2411,28 +2403,24 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
24112403
getChainOp()->printAsOperand(O, SlotTracker);
24122404
O << " +";
24132405
printFlags(O);
2414-
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
2406+
O << " reduce." << Instruction::getOpcodeName(getOpcode()) << " (";
24152407
getVecOp()->printAsOperand(O, SlotTracker);
24162408
if (isConditional()) {
24172409
O << ", ";
24182410
getCondOp()->printAsOperand(O, SlotTracker);
24192411
}
24202412
O << ")";
2421-
if (RdxDesc.IntermediateStore)
2422-
O << " (with final reduction value stored in invariant address sank "
2423-
"outside of loop)";
24242413
}
24252414

24262415
void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24272416
VPSlotTracker &SlotTracker) const {
2428-
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
24292417
O << Indent << "REDUCE ";
24302418
printAsOperand(O, SlotTracker);
24312419
O << " = ";
24322420
getChainOp()->printAsOperand(O, SlotTracker);
24332421
O << " +";
24342422
printFlags(O);
2435-
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
2423+
O << " vp.reduce." << Instruction::getOpcodeName(getOpcode()) << " (";
24362424
getVecOp()->printAsOperand(O, SlotTracker);
24372425
O << ", ";
24382426
getEVL()->printAsOperand(O, SlotTracker);
@@ -2441,9 +2429,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24412429
getCondOp()->printAsOperand(O, SlotTracker);
24422430
}
24432431
O << ")";
2444-
if (RdxDesc.IntermediateStore)
2445-
O << " (with final reduction value stored in invariant address sank "
2446-
"outside of loop)";
24472432
}
24482433
#endif
24492434

llvm/test/Transforms/LoopVectorize/vplan-printing.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
234234
; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
235235
; CHECK-NEXT: vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
236236
; CHECK-NEXT: WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
237-
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
237+
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
238238
; CHECK-NEXT: EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
239239
; CHECK-NEXT: EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
240240
; CHECK-NEXT: No successors

0 commit comments

Comments
 (0)