Skip to content

Commit 6a8606e

Browse files
authored
[VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI (#131300)
VPReductionRecipes take a RecurrenceDescriptor, but only use the RecurKind and FastMathFlags in it when executing. This patch makes the recipe more lightweight by stripping it to only take the latter two. The motivation for this is to simplify an upcoming patch to support in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to create an Or reduction, and by using RecurKind we can create an arbitrary reduction without needing a full RecurrenceDescriptor.
1 parent 63b5692 commit 6a8606e

File tree

5 files changed

+52
-56
lines changed

5 files changed

+52
-56
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9772,8 +9772,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97729772
if (CM.blockNeedsPredicationForAnyReason(BB))
97739773
CondOp = RecipeBuilder.getBlockInMask(BB);
97749774

9775+
// Non-FP RdxDescs will have all fast math flags set, so clear them.
9776+
FastMathFlags FMFs = isa<FPMathOperator>(CurrentLinkI)
9777+
? RdxDesc.getFastMathFlags()
9778+
: FastMathFlags();
97759779
auto *RedRecipe = new VPReductionRecipe(
9776-
RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
9780+
Kind, FMFs, CurrentLinkI, PreviousLink, VecOp, CondOp,
97779781
CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
97789782
// Append the recipe to the end of the VPBasicBlock because we need to
97799783
// ensure that it comes after all of its inputs, including CondOp.

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,22 +2239,19 @@ class VPInterleaveRecipe : public VPRecipeBase {
22392239
/// a vector operand into a scalar value, and adding the result to a chain.
22402240
/// The Operands are {ChainOp, VecOp, [Condition]}.
22412241
class VPReductionRecipe : public VPRecipeWithIRFlags {
2242-
/// The recurrence decriptor for the reduction in question.
2243-
const RecurrenceDescriptor &RdxDesc;
2242+
/// The recurrence kind for the reduction in question.
2243+
RecurKind RdxKind;
22442244
bool IsOrdered;
22452245
/// Whether the reduction is conditional.
22462246
bool IsConditional = false;
22472247

22482248
protected:
2249-
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2250-
Instruction *I, ArrayRef<VPValue *> Operands,
2251-
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2252-
: VPRecipeWithIRFlags(SC, Operands,
2253-
isa_and_nonnull<FPMathOperator>(I)
2254-
? R.getFastMathFlags()
2255-
: FastMathFlags(),
2256-
DL),
2257-
RdxDesc(R), IsOrdered(IsOrdered) {
2249+
VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
2250+
FastMathFlags FMFs, Instruction *I,
2251+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2252+
bool IsOrdered, DebugLoc DL)
2253+
: VPRecipeWithIRFlags(SC, Operands, FMFs, DL), RdxKind(RdxKind),
2254+
IsOrdered(IsOrdered) {
22582255
if (CondOp) {
22592256
IsConditional = true;
22602257
addOperand(CondOp);
@@ -2263,19 +2260,19 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
22632260
}
22642261

22652262
public:
2266-
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
2263+
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
22672264
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
22682265
bool IsOrdered, DebugLoc DL = {})
2269-
: VPReductionRecipe(VPDef::VPReductionSC, R, I,
2266+
: VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, I,
22702267
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
22712268
IsOrdered, DL) {}
22722269

22732270
~VPReductionRecipe() override = default;
22742271

22752272
VPReductionRecipe *clone() override {
2276-
return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
2277-
getVecOp(), getCondOp(), IsOrdered,
2278-
getDebugLoc());
2273+
return new VPReductionRecipe(RdxKind, getFastMathFlags(),
2274+
getUnderlyingInstr(), getChainOp(), getVecOp(),
2275+
getCondOp(), IsOrdered, getDebugLoc());
22792276
}
22802277

22812278
static inline bool classof(const VPRecipeBase *R) {
@@ -2301,10 +2298,8 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23012298
VPSlotTracker &SlotTracker) const override;
23022299
#endif
23032300

2304-
/// Return the recurrence decriptor for the in-loop reduction.
2305-
const RecurrenceDescriptor &getRecurrenceDescriptor() const {
2306-
return RdxDesc;
2307-
}
2301+
/// Return the recurrence kind for the in-loop reduction.
2302+
RecurKind getRecurrenceKind() const { return RdxKind; }
23082303
/// Return true if the in-loop reduction is ordered.
23092304
bool isOrdered() const { return IsOrdered; };
23102305
/// Return true if the in-loop reduction is conditional.
@@ -2328,7 +2323,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
23282323
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp,
23292324
DebugLoc DL = {})
23302325
: VPReductionRecipe(
2331-
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
2326+
VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
2327+
R.getFastMathFlags(),
23322328
cast_or_null<Instruction>(R.getUnderlyingValue()),
23332329
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
23342330
R.isOrdered(), DL) {}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,7 +2300,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
23002300
void VPReductionRecipe::execute(VPTransformState &State) {
23012301
assert(!State.Lane && "Reduction being replicated.");
23022302
Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
2303-
RecurKind Kind = RdxDesc.getRecurrenceKind();
2303+
RecurKind Kind = getRecurrenceKind();
23042304
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
23052305
"In-loop AnyOf reductions aren't currently supported");
23062306
// Propagate the fast-math flags carried by the underlying instruction.
@@ -2313,8 +2313,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
23132313
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
23142314
Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
23152315

2316-
Value *Start =
2317-
getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
2316+
Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
23182317
if (State.VF.isVector())
23192318
Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
23202319

@@ -2329,18 +2328,19 @@ void VPReductionRecipe::execute(VPTransformState &State) {
23292328
createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
23302329
else
23312330
NewRed = State.Builder.CreateBinOp(
2332-
(Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
2331+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind),
2332+
PrevInChain, NewVecOp);
23332333
PrevInChain = NewRed;
23342334
NextInChain = NewRed;
23352335
} else {
23362336
PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
23372337
NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
23382338
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
2339-
NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
2340-
NewRed, PrevInChain);
2339+
NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
23412340
else
23422341
NextInChain = State.Builder.CreateBinOp(
2343-
(Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
2342+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed,
2343+
PrevInChain);
23442344
}
23452345
State.set(this, NextInChain, /*IsScalar*/ true);
23462346
}
@@ -2351,10 +2351,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23512351
auto &Builder = State.Builder;
23522352
// Propagate the fast-math flags carried by the underlying instruction.
23532353
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
2354-
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
23552354
Builder.setFastMathFlags(getFastMathFlags());
23562355

2357-
RecurKind Kind = RdxDesc.getRecurrenceKind();
2356+
RecurKind Kind = getRecurrenceKind();
23582357
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
23592358
Value *VecOp = State.get(getVecOp());
23602359
Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2377,18 +2376,19 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23772376
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
23782377
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
23792378
else
2380-
NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
2381-
NewRed, Prev);
2379+
NewRed = Builder.CreateBinOp(
2380+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed,
2381+
Prev);
23822382
}
23832383
State.set(this, NewRed, /*IsScalar*/ true);
23842384
}
23852385

23862386
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23872387
VPCostContext &Ctx) const {
2388-
RecurKind RdxKind = RdxDesc.getRecurrenceKind();
2388+
RecurKind RdxKind = getRecurrenceKind();
23892389
Type *ElementTy = Ctx.Types.inferScalarType(this);
23902390
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
2391-
unsigned Opcode = RdxDesc.getOpcode();
2391+
unsigned Opcode = RecurrenceDescriptor::getOpcode(RdxKind);
23922392
FastMathFlags FMFs = getFastMathFlags();
23932393

23942394
// TODO: Support any-of and in-loop reductions.
@@ -2401,9 +2401,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
24012401
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
24022402
"In-loop reduction not implemented in VPlan-based cost model currently.");
24032403

2404-
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
2405-
"Inferred type and recurrence type mismatch.");
2406-
24072404
// Cost = Reduction cost + BinOp cost
24082405
InstructionCost Cost =
24092406
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
@@ -2426,28 +2423,30 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
24262423
getChainOp()->printAsOperand(O, SlotTracker);
24272424
O << " +";
24282425
printFlags(O);
2429-
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
2426+
O << " reduce."
2427+
<< Instruction::getOpcodeName(
2428+
RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
2429+
<< " (";
24302430
getVecOp()->printAsOperand(O, SlotTracker);
24312431
if (isConditional()) {
24322432
O << ", ";
24332433
getCondOp()->printAsOperand(O, SlotTracker);
24342434
}
24352435
O << ")";
2436-
if (RdxDesc.IntermediateStore)
2437-
O << " (with final reduction value stored in invariant address sank "
2438-
"outside of loop)";
24392436
}
24402437

24412438
void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24422439
VPSlotTracker &SlotTracker) const {
2443-
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
24442440
O << Indent << "REDUCE ";
24452441
printAsOperand(O, SlotTracker);
24462442
O << " = ";
24472443
getChainOp()->printAsOperand(O, SlotTracker);
24482444
O << " +";
24492445
printFlags(O);
2450-
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
2446+
O << " vp.reduce."
2447+
<< Instruction::getOpcodeName(
2448+
RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
2449+
<< " (";
24512450
getVecOp()->printAsOperand(O, SlotTracker);
24522451
O << ", ";
24532452
getEVL()->printAsOperand(O, SlotTracker);
@@ -2456,9 +2455,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24562455
getCondOp()->printAsOperand(O, SlotTracker);
24572456
}
24582457
O << ")";
2459-
if (RdxDesc.IntermediateStore)
2460-
O << " (with final reduction value stored in invariant address sank "
2461-
"outside of loop)";
24622458
}
24632459
#endif
24642460

llvm/test/Transforms/LoopVectorize/vplan-printing.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
234234
; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
235235
; CHECK-NEXT: vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
236236
; CHECK-NEXT: WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
237-
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
237+
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
238238
; CHECK-NEXT: EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
239239
; CHECK-NEXT: EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
240240
; CHECK-NEXT: No successors

llvm/unittests/Transforms/Vectorize/VPlanTest.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,8 +1170,8 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
11701170
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11711171
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11721172
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1173-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
1174-
VecOp, false);
1173+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1174+
CondOp, VecOp, false);
11751175
EXPECT_FALSE(Recipe.mayHaveSideEffects());
11761176
EXPECT_FALSE(Recipe.mayReadFromMemory());
11771177
EXPECT_FALSE(Recipe.mayWriteToMemory());
@@ -1185,8 +1185,8 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
11851185
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
11861186
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
11871187
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
1188-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
1189-
VecOp, false);
1188+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1189+
CondOp, VecOp, false);
11901190
VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4));
11911191
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
11921192
EXPECT_FALSE(EVLRecipe.mayHaveSideEffects());
@@ -1540,8 +1540,8 @@ TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
15401540
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
15411541
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
15421542
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1543-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1544-
false);
1543+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1544+
CondOp, VecOp, false);
15451545
EXPECT_TRUE(isa<VPUser>(&Recipe));
15461546
VPRecipeBase *BaseR = &Recipe;
15471547
EXPECT_TRUE(isa<VPUser>(BaseR));
@@ -1555,8 +1555,8 @@ TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
15551555
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
15561556
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
15571557
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
1558-
VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
1559-
false);
1558+
VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
1559+
CondOp, VecOp, false);
15601560
VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0));
15611561
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
15621562
EXPECT_TRUE(isa<VPUser>(&EVLRecipe));

0 commit comments

Comments
 (0)