@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
525
525
case VPRecipeBase::VPInstructionSC:
526
526
case VPRecipeBase::VPReductionEVLSC:
527
527
case VPRecipeBase::VPReductionSC:
528
+ case VPRecipeBase::VPMulAccumulateReductionSC:
529
+ case VPRecipeBase::VPExtendedReductionSC:
528
530
case VPRecipeBase::VPReplicateSC:
529
531
case VPRecipeBase::VPScalarIVStepsSC:
530
532
case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
609
611
DisjointFlagsTy (bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
610
612
};
611
613
614
/// Flag payload for operations that can carry a non-negative (nneg) marker.
/// Implicitly constructible from a bool so a plain flag can be passed where a
/// NonNegFlagsTy is expected.
struct NonNegFlagsTy {
  /// Single-bit storage for the flag.
  char NonNeg : 1;

  /// Store \p IsNonNeg in the one-bit field.
  NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
};
618
+
612
619
private:
613
620
struct ExactFlagsTy {
614
621
char IsExact : 1 ;
615
622
};
616
- struct NonNegFlagsTy {
617
- char NonNeg : 1 ;
618
- };
619
623
struct FastMathFlagsTy {
620
624
char AllowReassoc : 1 ;
621
625
char NoNaNs : 1 ;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
709
713
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
710
714
DisjointFlags(DisjointFlags) {}
711
715
716
+ template <typename IterT>
717
+ VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
718
+ NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
719
+ : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
720
+ NonNegFlags(NonNegFlags) {}
721
+
712
722
protected:
713
723
template <typename IterT>
714
724
VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
728
738
R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
729
739
R->getVPDefID () == VPRecipeBase::VPReplicateSC ||
730
740
R->getVPDefID () == VPRecipeBase::VPVectorEndPointerSC ||
731
- R->getVPDefID () == VPRecipeBase::VPVectorPointerSC;
741
+ R->getVPDefID () == VPRecipeBase::VPVectorPointerSC ||
742
+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
743
+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
732
744
}
733
745
734
746
static inline bool classof (const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
820
832
821
833
FastMathFlags getFastMathFlags () const ;
822
834
835
+ // / Returns true if the recipe has non-negative flag.
836
+ bool hasNonNegFlag () const { return OpType == OperationType::NonNegOp; }
837
+
838
+ bool isNonNeg () const {
839
+ assert (OpType == OperationType::NonNegOp &&
840
+ " recipe doesn't have a NNEG flag" );
841
+ return NonNegFlags.NonNeg ;
842
+ }
843
+
823
844
bool hasNoUnsignedWrap () const {
824
845
assert (OpType == OperationType::OverflowingBinOp &&
825
846
" recipe doesn't have a NUW flag" );
@@ -2373,6 +2394,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
2373
2394
setUnderlyingValue (I);
2374
2395
}
2375
2396
2397
+ // / For VPExtendedReductionRecipe.
2398
+ // / Note that the debug location is from the extend.
2399
+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2400
+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2401
+ bool IsOrdered, DebugLoc DL)
2402
+ : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2403
+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2404
+ if (CondOp)
2405
+ addOperand (CondOp);
2406
+ }
2407
+
2408
+ // / For VPMulAccumulateReductionRecipe.
2409
+ // / Note that the NUW/NSW flags and the debug location are from the Mul.
2410
+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2411
+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2412
+ bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2413
+ : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2414
+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2415
+ if (CondOp)
2416
+ addOperand (CondOp);
2417
+ }
2418
+
2376
2419
public:
2377
2420
VPReductionRecipe (RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
2378
2421
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2381,6 +2424,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
2381
2424
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2382
2425
IsOrdered, DL) {}
2383
2426
2427
+ VPReductionRecipe (const RecurKind RdxKind, FastMathFlags FMFs,
2428
+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2429
+ bool IsOrdered, DebugLoc DL = {})
2430
+ : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr ,
2431
+ ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2432
+ IsOrdered, DL) {}
2433
+
2384
2434
~VPReductionRecipe () override = default ;
2385
2435
2386
2436
VPReductionRecipe *clone () override {
@@ -2391,7 +2441,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
2391
2441
2392
2442
static inline bool classof (const VPRecipeBase *R) {
2393
2443
return R->getVPDefID () == VPRecipeBase::VPReductionSC ||
2394
- R->getVPDefID () == VPRecipeBase::VPReductionEVLSC;
2444
+ R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
2445
+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
2446
+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
2395
2447
}
2396
2448
2397
2449
static inline bool classof (const VPUser *U) {
@@ -2471,6 +2523,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
2471
2523
}
2472
2524
};
2473
2525
2526
+ // / A recipe to represent inloop extended reduction operations, performing a
2527
+ // / reduction on a extended vector operand into a scalar value, and adding the
2528
+ // / result to a chain. This recipe is abstract and needs to be lowered to
2529
+ // / concrete recipes before codegen. The operands are {ChainOp, VecOp,
2530
+ // / [Condition]}.
2531
+ class VPExtendedReductionRecipe : public VPReductionRecipe {
2532
+ // / Opcode of the extend recipe will be lowered to.
2533
+ Instruction::CastOps ExtOp;
2534
+
2535
+ Type *ResultTy;
2536
+
2537
+ // / For cloning VPExtendedReductionRecipe.
2538
+ VPExtendedReductionRecipe (VPExtendedReductionRecipe *ExtRed)
2539
+ : VPReductionRecipe(
2540
+ VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind (),
2541
+ {ExtRed->getChainOp (), ExtRed->getVecOp ()}, ExtRed->getCondOp (),
2542
+ ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2543
+ ExtOp(ExtRed->getExtOpcode ()), ResultTy(ExtRed->getResultType ()) {
2544
+ transferFlags (*ExtRed);
2545
+ }
2546
+
2547
+ public:
2548
+ VPExtendedReductionRecipe (VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2549
+ : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind (),
2550
+ {R->getChainOp (), Ext->getOperand (0 )}, R->getCondOp (),
2551
+ R->isOrdered(), Ext->getDebugLoc()),
2552
+ ExtOp(Ext->getOpcode ()), ResultTy(Ext->getResultType ()) {
2553
+ // Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2554
+ // the original recipe to prevent setting wrong flags.
2555
+ transferFlags (*Ext);
2556
+ }
2557
+
2558
+ ~VPExtendedReductionRecipe () override = default ;
2559
+
2560
+ VPExtendedReductionRecipe *clone () override {
2561
+ auto *Copy = new VPExtendedReductionRecipe (this );
2562
+ Copy->transferFlags (*this );
2563
+ return Copy;
2564
+ }
2565
+
2566
+ VP_CLASSOF_IMPL (VPDef::VPExtendedReductionSC);
2567
+
2568
+ void execute (VPTransformState &State) override {
2569
+ llvm_unreachable (" VPExtendedReductionRecipe should be transform to "
2570
+ " VPExtendedRecipe + VPReductionRecipe before execution." );
2571
+ };
2572
+
2573
+ // / Return the cost of VPExtendedReductionRecipe.
2574
+ InstructionCost computeCost (ElementCount VF,
2575
+ VPCostContext &Ctx) const override ;
2576
+
2577
+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2578
+ // / Print the recipe.
2579
+ void print (raw_ostream &O, const Twine &Indent,
2580
+ VPSlotTracker &SlotTracker) const override ;
2581
+ #endif
2582
+
2583
+ // / The scalar type after extending.
2584
+ Type *getResultType () const { return ResultTy; }
2585
+
2586
+ // / Is the extend ZExt?
2587
+ bool isZExt () const { return getExtOpcode () == Instruction::ZExt; }
2588
+
2589
+ // / The opcode of extend recipe.
2590
+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2591
+ };
2592
+
2593
+ // / A recipe to represent inloop MulAccumulateReduction operations, performing a
2594
+ // / reduction.add on the result of vector operands (might be extended)
2595
+ // / multiplication into a scalar value, and adding the result to a chain. This
2596
+ // / recipe is abstract and needs to be lowered to concrete recipes before
2597
+ // / codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2598
+ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2599
+ // / Opcode of the extend recipe.
2600
+ Instruction::CastOps ExtOp;
2601
+
2602
+ // / Non-neg flag of the extend recipe.
2603
+ bool IsNonNeg = false ;
2604
+
2605
+ Type *ResultTy;
2606
+
2607
+ // / For cloning VPMulAccumulateReductionRecipe.
2608
+ VPMulAccumulateReductionRecipe (VPMulAccumulateReductionRecipe *MulAcc)
2609
+ : VPReductionRecipe(
2610
+ VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind (),
2611
+ {MulAcc->getChainOp (), MulAcc->getVecOp0 (), MulAcc->getVecOp1 ()},
2612
+ MulAcc->getCondOp (), MulAcc->isOrdered(),
2613
+ WrapFlagsTy(MulAcc->hasNoUnsignedWrap (), MulAcc->hasNoSignedWrap()),
2614
+ MulAcc->getDebugLoc()),
2615
+ ExtOp(MulAcc->getExtOpcode ()), IsNonNeg(MulAcc->isNonNeg ()),
2616
+ ResultTy(MulAcc->getResultType ()) {}
2617
+
2618
+ public:
2619
+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul,
2620
+ VPWidenCastRecipe *Ext0,
2621
+ VPWidenCastRecipe *Ext1, Type *ResultTy)
2622
+ : VPReductionRecipe(
2623
+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2624
+ {R->getChainOp (), Ext0->getOperand (0 ), Ext1->getOperand (0 )},
2625
+ R->getCondOp (), R->isOrdered(),
2626
+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2627
+ R->getDebugLoc()),
2628
+ ExtOp(Ext0->getOpcode ()), ResultTy(ResultTy) {
2629
+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2630
+ Instruction::Add &&
2631
+ " The reduction instruction in MulAccumulateteReductionRecipe must "
2632
+ " be Add" );
2633
+ // Only set the non-negative flag if the original recipe contains.
2634
+ if (Ext0->hasNonNegFlag ())
2635
+ IsNonNeg = Ext0->isNonNeg ();
2636
+ }
2637
+
2638
+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul)
2639
+ : VPReductionRecipe(
2640
+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2641
+ {R->getChainOp (), Mul->getOperand (0 ), Mul->getOperand (1 )},
2642
+ R->getCondOp (), R->isOrdered(),
2643
+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2644
+ R->getDebugLoc()),
2645
+ ExtOp(Instruction::CastOps::CastOpsEnd) {
2646
+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2647
+ Instruction::Add &&
2648
+ " The reduction instruction in MulAccumulateReductionRecipe must be "
2649
+ " Add" );
2650
+ }
2651
+
2652
+ ~VPMulAccumulateReductionRecipe () override = default ;
2653
+
2654
+ VPMulAccumulateReductionRecipe *clone () override {
2655
+ auto *Copy = new VPMulAccumulateReductionRecipe (this );
2656
+ Copy->transferFlags (*this );
2657
+ return Copy;
2658
+ }
2659
+
2660
+ VP_CLASSOF_IMPL (VPDef::VPMulAccumulateReductionSC);
2661
+
2662
+ void execute (VPTransformState &State) override {
2663
+ llvm_unreachable (" VPMulAccumulateReductionRecipe should transform to "
2664
+ " VPWidenCastRecipe + "
2665
+ " VPWidenRecipe + VPReductionRecipe before execution" );
2666
+ }
2667
+
2668
+ // / Return the cost of VPMulAccumulateReductionRecipe.
2669
+ InstructionCost computeCost (ElementCount VF,
2670
+ VPCostContext &Ctx) const override ;
2671
+
2672
+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2673
+ // / Print the recipe.
2674
+ void print (raw_ostream &O, const Twine &Indent,
2675
+ VPSlotTracker &SlotTracker) const override ;
2676
+ #endif
2677
+
2678
+ Type *getResultType () const {
2679
+ assert (isExtended () && " Only support getResultType when this recipe "
2680
+ " contains implicit extend." );
2681
+ return ResultTy;
2682
+ }
2683
+
2684
+ // / The VPValue of the vector value to be extended and reduced.
2685
+ VPValue *getVecOp0 () const { return getOperand (1 ); }
2686
+ VPValue *getVecOp1 () const { return getOperand (2 ); }
2687
+
2688
+ // / Return if this MulAcc recipe contains extended operands.
2689
+ bool isExtended () const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2690
+
2691
+ // / Return the opcode of the extends for the operands.
2692
+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2693
+
2694
+ // / Return if the operands are zero extended.
2695
+ bool isZExt () const { return ExtOp == Instruction::CastOps::ZExt; }
2696
+
2697
+ // / Return the non negative flag of the ext recipe.
2698
+ bool isNonNeg () const { return IsNonNeg; }
2699
+ };
2700
+
2474
2701
// / VPReplicateRecipe replicates a given instruction producing multiple scalar
2475
2702
// / copies of the original scalar type, one per lane, instead of producing a
2476
2703
// / single copy of widened type for all lanes. If the instruction is known to be
0 commit comments