@@ -56,18 +56,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-transform-all", cl::Hidden,
56
56
cl::init (false ),
57
57
cl::desc(" Disable Loop Idiom Transform Pass." ));
58
58
59
// Hidden command-line option selecting how the loop-idiom transform emits its
// vector loop: either masked vector intrinsics or VP intrinsics. Defaults to
// the masked style. The pass's run() only lets this option replace its own
// configured style when the option actually occurs on the command line.
+ static cl::opt<LoopIdiomTransformStyle>
60
+ LITVecStyle (" loop-idiom-transform-style" , cl::Hidden,
61
+ cl::desc (" The vectorization style for loop idiom transform." ),
62
+ cl::values(clEnumValN(LoopIdiomTransformStyle::Masked, " masked" ,
63
+ " Use masked vector intrinsics" ),
64
+ clEnumValN(LoopIdiomTransformStyle::Predicated,
65
+ " predicated" , " Use VP intrinsics" )),
66
+ cl::init(LoopIdiomTransformStyle::Masked));
67
+
59
68
// Hidden boolean option, off by default. Per its description, the pass still
// runs when this is set but leaves byte-compare loops unconverted.
// NOTE(review): the code that consults this flag is outside this excerpt.
static cl::opt<bool >
60
69
DisableByteCmp (" disable-loop-idiom-transform-bytecmp" , cl::Hidden,
61
70
cl::init (false ),
62
71
cl::desc(" Proceed with Loop Idiom Transform Pass, but do "
63
72
" not convert byte-compare loop(s)." ));
64
73
74
// Hidden option giving the vectorization factor for byte-compare patterns
// (default 16). The pass's run() uses it in place of its configured
// ByteCompareVF only when the option occurs on the command line. The chosen
// value becomes the minimum element count of the scalable vector types built
// for the byte-compare loop.
// NOTE(review): no range or power-of-two validation is visible here — confirm
// whether a value of 0 or an odd value is rejected elsewhere.
+ static cl::opt<unsigned >
75
+ ByteCmpVF (" loop-idiom-transform-bytecmp-vf" , cl::Hidden,
76
+ cl::desc (" The vectorization factor for byte-compare patterns." ),
77
+ cl::init(16 ));
78
+
65
79
// Hidden boolean option, off by default. Per its description, it requests
// verification of the loops this pass generates.
// NOTE(review): the verification code that reads this flag is outside this
// excerpt.
static cl::opt<bool >
66
80
VerifyLoops (" verify-loop-idiom-transform" , cl::Hidden, cl::init(false ),
67
81
cl::desc(" Verify loops generated Loop Idiom Transform Pass." ));
68
82
69
83
namespace {
70
84
class LoopIdiomTransform {
85
+ LoopIdiomTransformStyle VectorizeStyle;
86
+ unsigned ByteCompareVF;
71
87
Loop *CurLoop = nullptr ;
72
88
DominatorTree *DT;
73
89
LoopInfo *LI;
@@ -82,10 +98,11 @@ class LoopIdiomTransform {
82
98
BasicBlock *VectorLoopIncBlock = nullptr ;
83
99
84
100
public:
85
- explicit LoopIdiomTransform (DominatorTree *DT, LoopInfo *LI,
86
- const TargetTransformInfo *TTI,
87
- const DataLayout *DL)
88
- : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
101
+ LoopIdiomTransform (LoopIdiomTransformStyle S, unsigned VF, DominatorTree *DT,
102
+ LoopInfo *LI, const TargetTransformInfo *TTI,
103
+ const DataLayout *DL)
104
+ : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
105
+ }
89
106
90
107
bool run (Loop *L);
91
108
@@ -106,6 +123,10 @@ class LoopIdiomTransform {
106
123
Value *createMaskedFindMismatch (IRBuilder<> &Builder, GetElementPtrInst *GEPA,
107
124
GetElementPtrInst *GEPB, Value *ExtStart,
108
125
Value *ExtEnd);
126
+ Value *createPredicatedFindMismatch (IRBuilder<> &Builder,
127
+ GetElementPtrInst *GEPA,
128
+ GetElementPtrInst *GEPB, Value *ExtStart,
129
+ Value *ExtEnd);
109
130
110
131
void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
111
132
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -123,7 +144,15 @@ PreservedAnalyses LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
123
144
124
145
const auto *DL = &L.getHeader ()->getModule ()->getDataLayout ();
125
146
126
- LoopIdiomTransform LIT (&AR.DT , &AR.LI , &AR.TTI , DL);
147
+ LoopIdiomTransformStyle VecStyle = VectorizeStyle;
148
+ if (LITVecStyle.getNumOccurrences ())
149
+ VecStyle = LITVecStyle;
150
+
151
+ unsigned BCVF = ByteCompareVF;
152
+ if (ByteCmpVF.getNumOccurrences ())
153
+ BCVF = ByteCmpVF;
154
+
155
+ LoopIdiomTransform LIT (VecStyle, BCVF, &AR.DT , &AR.LI , &AR.TTI , DL);
127
156
if (!LIT.run (&L))
128
157
return PreservedAnalyses::all ();
129
158
@@ -357,14 +386,15 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
357
386
// Therefore, we know that we can use a 64-bit induction variable that
358
387
// starts from 0 -> ExtMaxLen and it will not overflow.
359
388
ScalableVectorType *PredVTy =
360
- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
389
+ ScalableVectorType::get (Builder.getInt1Ty (), ByteCompareVF );
361
390
362
391
Value *InitialPred = Builder.CreateIntrinsic (
363
392
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
364
393
365
394
Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
366
- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
367
- /* HasNUW=*/ true , /* HasNSW=*/ true );
395
+ VecLen =
396
+ Builder.CreateMul (VecLen, ConstantInt::get (I64Type, ByteCompareVF), " " ,
397
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
368
398
369
399
Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
370
400
Builder.getInt1 (false ));
@@ -379,7 +409,8 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
379
409
LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
380
410
PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
381
411
VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
382
- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
412
+ Type *VectorLoadType =
413
+ ScalableVectorType::get (Builder.getInt8Ty (), ByteCompareVF);
383
414
Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
384
415
385
416
Value *VectorLhsGep = Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi);
@@ -442,6 +473,112 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
442
473
return Builder.CreateTrunc (VectorLoopRes64, ResType );
443
474
}
444
475
476
+ Value *LoopIdiomTransform::createPredicatedFindMismatch (IRBuilder<> &Builder,
477
+ GetElementPtrInst *GEPA,
478
+ GetElementPtrInst *GEPB,
479
+ Value *ExtStart,
480
+ Value *ExtEnd) {
481
+ Type *I64Type = Builder.getInt64Ty ();
482
+ Type *I32Type = Builder.getInt32Ty ();
483
+ Type *ResType = I32Type;
484
+ Type *LoadType = Builder.getInt8Ty ();
485
+ Value *PtrA = GEPA->getPointerOperand ();
486
+ Value *PtrB = GEPB->getPointerOperand ();
487
+
488
+ // At this point we know two things must be true:
489
+ // 1. Start <= End
490
+ // 2. ExtMaxLen <= 4096 due to the page checks.
491
+ // Therefore, we know that we can use a 64-bit induction variable that
492
+ // starts from 0 -> ExtMaxLen and it will not overflow.
493
+ auto *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
494
+ Builder.Insert (JumpToVectorLoop);
495
+
496
+ // Set up the first Vector loop block by creating the PHIs, doing the vector
497
+ // loads and comparing the vectors.
498
+ Builder.SetInsertPoint (VectorLoopStartBlock);
499
+ auto *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vector_index" );
500
+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
501
+
502
+ // Calculate AVL by subtracting the vector loop index from the trip count
503
+ Value *AVL = Builder.CreateSub (ExtEnd, VectorIndexPhi, " avl" , /* HasNUW=*/ true ,
504
+ /* HasNSW=*/ true );
505
+
506
+ auto *VectorLoadType = ScalableVectorType::get (LoadType, ByteCompareVF);
507
+ auto *VF = ConstantInt::get (
508
+ I32Type, VectorLoadType->getElementCount ().getKnownMinValue ());
509
+ auto *IsScalable = ConstantInt::getBool (
510
+ Builder.getContext (), VectorLoadType->getElementCount ().isScalable ());
511
+
512
+ Value *VL = Builder.CreateIntrinsic (Intrinsic::experimental_get_vector_length,
513
+ {I64Type}, {AVL, VF, IsScalable});
514
+ Value *GepOffset = VectorIndexPhi;
515
+
516
+ Value *VectorLhsGep = Builder.CreateGEP (LoadType, PtrA, GepOffset);
517
+ if (GEPA->isInBounds ())
518
+ cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds (true );
519
+ VectorType *TrueMaskTy =
520
+ VectorType::get (Builder.getInt1Ty (), VectorLoadType->getElementCount ());
521
+ Value *AllTrueMask = Constant::getAllOnesValue (TrueMaskTy);
522
+ Value *VectorLhsLoad = Builder.CreateIntrinsic (
523
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
524
+ {VectorLhsGep, AllTrueMask, VL}, nullptr , " lhs.load" );
525
+
526
+ Value *VectorRhsGep = Builder.CreateGEP (LoadType, PtrB, GepOffset);
527
+ if (GEPB->isInBounds ())
528
+ cast<GetElementPtrInst>(VectorRhsGep)->setIsInBounds (true );
529
+ Value *VectorRhsLoad = Builder.CreateIntrinsic (
530
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
531
+ {VectorRhsGep, AllTrueMask, VL}, nullptr , " rhs.load" );
532
+
533
+ StringRef PredicateStr = CmpInst::getPredicateName (CmpInst::ICMP_NE);
534
+ auto *PredicateMDS = MDString::get (VectorLhsLoad->getContext (), PredicateStr);
535
+ Value *Pred = MetadataAsValue::get (VectorLhsLoad->getContext (), PredicateMDS);
536
+ Value *VectorMatchCmp = Builder.CreateIntrinsic (
537
+ Intrinsic::vp_icmp, {VectorLhsLoad->getType ()},
538
+ {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr ,
539
+ " mismatch.cmp" );
540
+ Value *CTZ = Builder.CreateIntrinsic (
541
+ Intrinsic::vp_cttz_elts, {ResType , VectorMatchCmp->getType ()},
542
+ {VectorMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true ), AllTrueMask,
543
+ VL});
544
+ // RISC-V refines/lowers the poison returned by vp.cttz.elts to -1.
545
+ Value *MismatchFound =
546
+ Builder.CreateICmpSGE (CTZ, ConstantInt::get (ResType , 0 ));
547
+ auto *VectorEarlyExit = BranchInst::Create (VectorLoopMismatchBlock,
548
+ VectorLoopIncBlock, MismatchFound);
549
+ Builder.Insert (VectorEarlyExit);
550
+
551
+ // Increment the index counter and calculate the predicate for the next
552
+ // iteration of the loop. We branch back to the start of the loop if there
553
+ // is at least one active lane.
554
+ Builder.SetInsertPoint (VectorLoopIncBlock);
555
+ Value *VL64 = Builder.CreateZExt (VL, I64Type);
556
+ Value *NewVectorIndexPhi =
557
+ Builder.CreateAdd (VectorIndexPhi, VL64, " " ,
558
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
559
+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
560
+ Value *ExitCond = Builder.CreateICmpNE (NewVectorIndexPhi, ExtEnd);
561
+ auto *VectorLoopBranchBack =
562
+ BranchInst::Create (VectorLoopStartBlock, EndBlock, ExitCond);
563
+ Builder.Insert (VectorLoopBranchBack);
564
+
565
+ // If we found a mismatch then we need to calculate which lane in the vector
566
+ // had a mismatch and add that on to the current loop index.
567
+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
568
+
569
+ // Add LCSSA phis for CTZ and VectorIndexPhi.
570
+ auto *CTZLCSSAPhi = Builder.CreatePHI (CTZ->getType (), 1 , " ctz" );
571
+ CTZLCSSAPhi->addIncoming (CTZ, VectorLoopStartBlock);
572
+ auto *VectorIndexLCSSAPhi =
573
+ Builder.CreatePHI (VectorIndexPhi->getType (), 1 , " mismatch_vector_index" );
574
+ VectorIndexLCSSAPhi->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
575
+
576
+ Value *CTZI64 = Builder.CreateZExt (CTZLCSSAPhi, I64Type);
577
+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorIndexLCSSAPhi, CTZI64, " " ,
578
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
579
+ return Builder.CreateTrunc (VectorLoopRes64, ResType );
580
+ }
581
+
445
582
Value *LoopIdiomTransform::expandFindMismatch (
446
583
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
447
584
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -593,8 +730,17 @@ Value *LoopIdiomTransform::expandFindMismatch(
593
730
// processed in each iteration, etc.
594
731
Builder.SetInsertPoint (VectorLoopPreheaderBlock);
595
732
596
- Value *VectorLoopRes =
597
- createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
733
+ Value *VectorLoopRes = nullptr ;
734
+ switch (VectorizeStyle) {
735
+ case LoopIdiomTransformStyle::Masked:
736
+ VectorLoopRes =
737
+ createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
738
+ break ;
739
+ case LoopIdiomTransformStyle::Predicated:
740
+ VectorLoopRes =
741
+ createPredicatedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
742
+ break ;
743
+ }
598
744
599
745
Builder.Insert (BranchInst::Create (EndBlock));
600
746
0 commit comments