@@ -2300,7 +2300,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
2300
2300
void VPReductionRecipe::execute (VPTransformState &State) {
2301
2301
assert (!State.Lane && " Reduction being replicated." );
2302
2302
Value *PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2303
- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2303
+ RecurKind Kind = getRecurrenceKind ();
2304
2304
assert (!RecurrenceDescriptor::isAnyOfRecurrenceKind (Kind) &&
2305
2305
" In-loop AnyOf reductions aren't currently supported" );
2306
2306
// Propagate the fast-math flags carried by the underlying instruction.
@@ -2313,8 +2313,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
2313
2313
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType ());
2314
2314
Type *ElementTy = VecTy ? VecTy->getElementType () : NewVecOp->getType ();
2315
2315
2316
- Value *Start =
2317
- getRecurrenceIdentity (Kind, ElementTy, RdxDesc.getFastMathFlags ());
2316
+ Value *Start = getRecurrenceIdentity (Kind, ElementTy, getFastMathFlags ());
2318
2317
if (State.VF .isVector ())
2319
2318
Start = State.Builder .CreateVectorSplat (VecTy->getElementCount (), Start);
2320
2319
@@ -2329,18 +2328,19 @@ void VPReductionRecipe::execute(VPTransformState &State) {
2329
2328
createOrderedReduction (State.Builder , Kind, NewVecOp, PrevInChain);
2330
2329
else
2331
2330
NewRed = State.Builder .CreateBinOp (
2332
- (Instruction::BinaryOps)RdxDesc.getOpcode (), PrevInChain, NewVecOp);
2331
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind),
2332
+ PrevInChain, NewVecOp);
2333
2333
PrevInChain = NewRed;
2334
2334
NextInChain = NewRed;
2335
2335
} else {
2336
2336
PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2337
2337
NewRed = createSimpleReduction (State.Builder , NewVecOp, Kind);
2338
2338
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2339
- NextInChain = createMinMaxOp (State.Builder , RdxDesc.getRecurrenceKind (),
2340
- NewRed, PrevInChain);
2339
+ NextInChain = createMinMaxOp (State.Builder , Kind, NewRed, PrevInChain);
2341
2340
else
2342
2341
NextInChain = State.Builder .CreateBinOp (
2343
- (Instruction::BinaryOps)RdxDesc.getOpcode (), NewRed, PrevInChain);
2342
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind), NewRed,
2343
+ PrevInChain);
2344
2344
}
2345
2345
State.set (this , NextInChain, /* IsScalar*/ true );
2346
2346
}
@@ -2351,10 +2351,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
2351
2351
auto &Builder = State.Builder ;
2352
2352
// Propagate the fast-math flags carried by the underlying instruction.
2353
2353
IRBuilderBase::FastMathFlagGuard FMFGuard (Builder);
2354
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2355
2354
Builder.setFastMathFlags (getFastMathFlags ());
2356
2355
2357
- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2356
+ RecurKind Kind = getRecurrenceKind ();
2358
2357
Value *Prev = State.get (getChainOp (), /* IsScalar*/ true );
2359
2358
Value *VecOp = State.get (getVecOp ());
2360
2359
Value *EVL = State.get (getEVL (), VPLane (0 ));
@@ -2377,18 +2376,19 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
2377
2376
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2378
2377
NewRed = createMinMaxOp (Builder, Kind, NewRed, Prev);
2379
2378
else
2380
- NewRed = Builder.CreateBinOp ((Instruction::BinaryOps)RdxDesc.getOpcode (),
2381
- NewRed, Prev);
2379
+ NewRed = Builder.CreateBinOp (
2380
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind), NewRed,
2381
+ Prev);
2382
2382
}
2383
2383
State.set (this , NewRed, /* IsScalar*/ true );
2384
2384
}
2385
2385
2386
2386
InstructionCost VPReductionRecipe::computeCost (ElementCount VF,
2387
2387
VPCostContext &Ctx) const {
2388
- RecurKind RdxKind = RdxDesc. getRecurrenceKind ();
2388
+ RecurKind RdxKind = getRecurrenceKind ();
2389
2389
Type *ElementTy = Ctx.Types .inferScalarType (this );
2390
2390
auto *VectorTy = cast<VectorType>(toVectorTy (ElementTy, VF));
2391
- unsigned Opcode = RdxDesc. getOpcode ();
2391
+ unsigned Opcode = RecurrenceDescriptor:: getOpcode (RdxKind );
2392
2392
FastMathFlags FMFs = getFastMathFlags ();
2393
2393
2394
2394
// TODO: Support any-of and in-loop reductions.
@@ -2401,9 +2401,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
2401
2401
ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
2402
2402
" In-loop reduction not implemented in VPlan-based cost model currently." );
2403
2403
2404
- assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
2405
- " Inferred type and recurrence type mismatch." );
2406
-
2407
2404
// Cost = Reduction cost + BinOp cost
2408
2405
InstructionCost Cost =
2409
2406
Ctx.TTI .getArithmeticInstrCost (Opcode, ElementTy, Ctx.CostKind );
@@ -2426,28 +2423,30 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
2426
2423
getChainOp ()->printAsOperand (O, SlotTracker);
2427
2424
O << " +" ;
2428
2425
printFlags (O);
2429
- O << " reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2426
+ O << " reduce."
2427
+ << Instruction::getOpcodeName (
2428
+ RecurrenceDescriptor::getOpcode (getRecurrenceKind ()))
2429
+ << " (" ;
2430
2430
getVecOp ()->printAsOperand (O, SlotTracker);
2431
2431
if (isConditional ()) {
2432
2432
O << " , " ;
2433
2433
getCondOp ()->printAsOperand (O, SlotTracker);
2434
2434
}
2435
2435
O << " )" ;
2436
- if (RdxDesc.IntermediateStore )
2437
- O << " (with final reduction value stored in invariant address sank "
2438
- " outside of loop)" ;
2439
2436
}
2440
2437
2441
2438
void VPReductionEVLRecipe::print (raw_ostream &O, const Twine &Indent,
2442
2439
VPSlotTracker &SlotTracker) const {
2443
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2444
2440
O << Indent << " REDUCE " ;
2445
2441
printAsOperand (O, SlotTracker);
2446
2442
O << " = " ;
2447
2443
getChainOp ()->printAsOperand (O, SlotTracker);
2448
2444
O << " +" ;
2449
2445
printFlags (O);
2450
- O << " vp.reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2446
+ O << " vp.reduce."
2447
+ << Instruction::getOpcodeName (
2448
+ RecurrenceDescriptor::getOpcode (getRecurrenceKind ()))
2449
+ << " (" ;
2451
2450
getVecOp ()->printAsOperand (O, SlotTracker);
2452
2451
O << " , " ;
2453
2452
getEVL ()->printAsOperand (O, SlotTracker);
@@ -2456,9 +2455,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
2456
2455
getCondOp ()->printAsOperand (O, SlotTracker);
2457
2456
}
2458
2457
O << " )" ;
2459
- if (RdxDesc.IntermediateStore )
2460
- O << " (with final reduction value stored in invariant address sank "
2461
- " outside of loop)" ;
2462
2458
}
2463
2459
#endif
2464
2460
0 commit comments