@@ -2285,7 +2285,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
2285
2285
void VPReductionRecipe::execute (VPTransformState &State) {
2286
2286
assert (!State.Lane && " Reduction being replicated." );
2287
2287
Value *PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2288
- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2288
+ RecurKind Kind = getRecurrenceKind ();
2289
2289
assert (!RecurrenceDescriptor::isAnyOfRecurrenceKind (Kind) &&
2290
2290
" In-loop AnyOf reductions aren't currently supported" );
2291
2291
// Propagate the fast-math flags carried by the underlying instruction.
@@ -2298,8 +2298,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
2298
2298
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType ());
2299
2299
Type *ElementTy = VecTy ? VecTy->getElementType () : NewVecOp->getType ();
2300
2300
2301
- Value *Start =
2302
- getRecurrenceIdentity (Kind, ElementTy, RdxDesc.getFastMathFlags ());
2301
+ Value *Start = getRecurrenceIdentity (Kind, ElementTy, getFastMathFlags ());
2303
2302
if (State.VF .isVector ())
2304
2303
Start = State.Builder .CreateVectorSplat (VecTy->getElementCount (), Start);
2305
2304
@@ -2311,21 +2310,20 @@ void VPReductionRecipe::execute(VPTransformState &State) {
2311
2310
if (IsOrdered) {
2312
2311
if (State.VF .isVector ())
2313
2312
NewRed =
2314
- createOrderedReduction (State.Builder , RdxDesc , NewVecOp, PrevInChain);
2313
+ createOrderedReduction (State.Builder , Kind , NewVecOp, PrevInChain);
2315
2314
else
2316
- NewRed = State.Builder .CreateBinOp (
2317
- (Instruction::BinaryOps)RdxDesc. getOpcode (), PrevInChain, NewVecOp);
2315
+ NewRed = State.Builder .CreateBinOp ((Instruction::BinaryOps) getOpcode (),
2316
+ PrevInChain, NewVecOp);
2318
2317
PrevInChain = NewRed;
2319
2318
NextInChain = NewRed;
2320
2319
} else {
2321
2320
PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2322
2321
NewRed = createSimpleReduction (State.Builder , NewVecOp, Kind);
2323
2322
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2324
- NextInChain = createMinMaxOp (State.Builder , RdxDesc.getRecurrenceKind (),
2325
- NewRed, PrevInChain);
2323
+ NextInChain = createMinMaxOp (State.Builder , Kind, NewRed, PrevInChain);
2326
2324
else
2327
2325
NextInChain = State.Builder .CreateBinOp (
2328
- (Instruction::BinaryOps)RdxDesc. getOpcode (), NewRed, PrevInChain);
2326
+ (Instruction::BinaryOps)getOpcode (), NewRed, PrevInChain);
2329
2327
}
2330
2328
State.set (this , NextInChain, /* IsScalar*/ true );
2331
2329
}
@@ -2336,10 +2334,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
2336
2334
auto &Builder = State.Builder ;
2337
2335
// Propagate the fast-math flags carried by the underlying instruction.
2338
2336
IRBuilderBase::FastMathFlagGuard FMFGuard (Builder);
2339
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2340
2337
Builder.setFastMathFlags (getFastMathFlags ());
2341
2338
2342
- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2339
+ RecurKind Kind = getRecurrenceKind ();
2343
2340
Value *Prev = State.get (getChainOp (), /* IsScalar*/ true );
2344
2341
Value *VecOp = State.get (getVecOp ());
2345
2342
Value *EVL = State.get (getEVL (), VPLane (0 ));
@@ -2356,25 +2353,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
2356
2353
2357
2354
Value *NewRed;
2358
2355
if (isOrdered ()) {
2359
- NewRed = createOrderedReduction (VBuilder, RdxDesc , VecOp, Prev);
2356
+ NewRed = createOrderedReduction (VBuilder, Kind , VecOp, Prev);
2360
2357
} else {
2361
- NewRed = createSimpleReduction (VBuilder, VecOp, RdxDesc );
2358
+ NewRed = createSimpleReduction (VBuilder, VecOp, Kind, getFastMathFlags () );
2362
2359
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2363
2360
NewRed = createMinMaxOp (Builder, Kind, NewRed, Prev);
2364
2361
else
2365
- NewRed = Builder.CreateBinOp ((Instruction::BinaryOps)RdxDesc. getOpcode (),
2366
- NewRed, Prev);
2362
+ NewRed = Builder.CreateBinOp ((Instruction::BinaryOps)getOpcode (), NewRed ,
2363
+ Prev);
2367
2364
}
2368
2365
State.set (this , NewRed, /* IsScalar*/ true );
2369
2366
}
2370
2367
2371
2368
InstructionCost VPReductionRecipe::computeCost (ElementCount VF,
2372
2369
VPCostContext &Ctx) const {
2373
- RecurKind RdxKind = RdxDesc. getRecurrenceKind ();
2370
+ RecurKind RdxKind = getRecurrenceKind ();
2374
2371
Type *ElementTy = Ctx.Types .inferScalarType (this );
2375
2372
auto *VectorTy = cast<VectorType>(toVectorTy (ElementTy, VF));
2376
- unsigned Opcode = RdxDesc.getOpcode ();
2377
- FastMathFlags FMFs = getFastMathFlags ();
2378
2373
2379
2374
// TODO: Support any-of and in-loop reductions.
2380
2375
assert (
@@ -2386,20 +2381,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
2386
2381
ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
2387
2382
" In-loop reduction not implemented in VPlan-based cost model currently." );
2388
2383
2389
- assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
2390
- " Inferred type and recurrence type mismatch." );
2391
-
2392
2384
// Cost = Reduction cost + BinOp cost
2393
2385
InstructionCost Cost =
2394
- Ctx.TTI .getArithmeticInstrCost (Opcode , ElementTy, Ctx.CostKind );
2386
+ Ctx.TTI .getArithmeticInstrCost (getOpcode () , ElementTy, Ctx.CostKind );
2395
2387
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RdxKind)) {
2396
2388
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp (RdxKind);
2397
- return Cost +
2398
- Ctx. TTI . getMinMaxReductionCost ( Id, VectorTy, FMFs , Ctx.CostKind );
2389
+ return Cost + Ctx. TTI . getMinMaxReductionCost (
2390
+ Id, VectorTy, getFastMathFlags () , Ctx.CostKind );
2399
2391
}
2400
2392
2401
- return Cost + Ctx.TTI .getArithmeticReductionCost (Opcode, VectorTy, FMFs,
2402
- Ctx.CostKind );
2393
+ return Cost + Ctx.TTI .getArithmeticReductionCost (
2394
+ getOpcode (), VectorTy, getFastMathFlags (), Ctx.CostKind );
2403
2395
}
2404
2396
2405
2397
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2411,28 +2403,24 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
2411
2403
getChainOp ()->printAsOperand (O, SlotTracker);
2412
2404
O << " +" ;
2413
2405
printFlags (O);
2414
- O << " reduce." << Instruction::getOpcodeName (RdxDesc. getOpcode ()) << " (" ;
2406
+ O << " reduce." << Instruction::getOpcodeName (getOpcode ()) << " (" ;
2415
2407
getVecOp ()->printAsOperand (O, SlotTracker);
2416
2408
if (isConditional ()) {
2417
2409
O << " , " ;
2418
2410
getCondOp ()->printAsOperand (O, SlotTracker);
2419
2411
}
2420
2412
O << " )" ;
2421
- if (RdxDesc.IntermediateStore )
2422
- O << " (with final reduction value stored in invariant address sank "
2423
- " outside of loop)" ;
2424
2413
}
2425
2414
2426
2415
void VPReductionEVLRecipe::print (raw_ostream &O, const Twine &Indent,
2427
2416
VPSlotTracker &SlotTracker) const {
2428
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2429
2417
O << Indent << " REDUCE " ;
2430
2418
printAsOperand (O, SlotTracker);
2431
2419
O << " = " ;
2432
2420
getChainOp ()->printAsOperand (O, SlotTracker);
2433
2421
O << " +" ;
2434
2422
printFlags (O);
2435
- O << " vp.reduce." << Instruction::getOpcodeName (RdxDesc. getOpcode ()) << " (" ;
2423
+ O << " vp.reduce." << Instruction::getOpcodeName (getOpcode ()) << " (" ;
2436
2424
getVecOp ()->printAsOperand (O, SlotTracker);
2437
2425
O << " , " ;
2438
2426
getEVL ()->printAsOperand (O, SlotTracker);
@@ -2441,9 +2429,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
2441
2429
getCondOp ()->printAsOperand (O, SlotTracker);
2442
2430
}
2443
2431
O << " )" ;
2444
- if (RdxDesc.IntermediateStore )
2445
- O << " (with final reduction value stored in invariant address sank "
2446
- " outside of loop)" ;
2447
2432
}
2448
2433
#endif
2449
2434
0 commit comments