Skip to content

Commit 764aac0

Browse files
committed
[RISCV][LoopIdiom] Support VP intrinsics in LoopIdiomTransform
Teach LoopIdiomTransform to emit VP (vector-predication) intrinsics when replacing byte-compare loops. Right now RISC-V is the only target that uses this style of LoopIdiomTransform.
1 parent b1420ba commit 764aac0

File tree

6 files changed

+1938
-12
lines changed

6 files changed

+1938
-12
lines changed

llvm/include/llvm/Transforms/Vectorize/LoopIdiomTransform.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,22 @@
1313
#include "llvm/Transforms/Scalar/LoopPassManager.h"
1414

1515
namespace llvm {
16-
struct LoopIdiomTransformPass : PassInfoMixin<LoopIdiomTransformPass> {
16+
/// Selects the kind of vector intrinsics LoopIdiomTransform emits:
/// `Masked` uses masked vector intrinsics, `Predicated` uses VP intrinsics.
enum class LoopIdiomTransformStyle { Masked, Predicated };
17+
18+
/// Loop pass wrapper around LoopIdiomTransform. It carries the vectorization
/// style and the byte-compare vectorization factor that the transform is
/// configured with when the pass runs.
class LoopIdiomTransformPass : public PassInfoMixin<LoopIdiomTransformPass> {
  /// Family of vector intrinsics to emit; masked unless a constructor says
  /// otherwise.
  LoopIdiomTransformStyle VectorizeStyle{LoopIdiomTransformStyle::Masked};

  /// The VF used in vectorizing the byte compare pattern.
  unsigned ByteCompareVF{16};

public:
  /// Masked style with the default byte-compare VF of 16.
  LoopIdiomTransformPass() = default;

  /// Use style \p S and keep the default byte-compare VF.
  explicit LoopIdiomTransformPass(LoopIdiomTransformStyle S)
      : VectorizeStyle{S} {}

  /// Use style \p S together with the byte-compare VF \p BCVF.
  LoopIdiomTransformPass(LoopIdiomTransformStyle S, unsigned BCVF)
      : VectorizeStyle{S}, ByteCompareVF{BCVF} {}

  /// Run the transform on loop \p L.
  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
};

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@
3333
#include "llvm/CodeGen/TargetPassConfig.h"
3434
#include "llvm/InitializePasses.h"
3535
#include "llvm/MC/TargetRegistry.h"
36+
#include "llvm/Passes/PassBuilder.h"
3637
#include "llvm/Support/FormattedStream.h"
3738
#include "llvm/Target/TargetOptions.h"
3839
#include "llvm/Transforms/IPO.h"
3940
#include "llvm/Transforms/Scalar.h"
41+
#include "llvm/Transforms/Vectorize/LoopIdiomTransform.h"
4042
#include <optional>
4143
using namespace llvm;
4244

@@ -576,6 +578,14 @@ void RISCVPassConfig::addPostRegAlloc() {
576578
addPass(createRISCVRedundantCopyEliminationPass());
577579
}
578580

581+
/// Registers RISC-V specific callbacks with the pass builder \p PB.
///
/// Adds a LoopIdiomTransformPass configured for the Predicated (VP intrinsic)
/// style at the late-loop-optimizations extension point.
/// \p PopulateClassToPassNames is not consulted here.
void RISCVTargetMachine::registerPassBuilderCallbacks(
    PassBuilder &PB, bool PopulateClassToPassNames) {
  // Name the callback first so the registration call reads as a single step.
  auto AddPredicatedLoopIdiomTransform = [=](LoopPassManager &LoopPM,
                                             OptimizationLevel OptLevel) {
    LoopPM.addPass(LoopIdiomTransformPass(LoopIdiomTransformStyle::Predicated));
  };
  PB.registerLateLoopOptimizationsEPCallback(AddPredicatedLoopIdiomTransform);
}
588+
579589
yaml::MachineFunctionInfo *
580590
RISCVTargetMachine::createDefaultFuncInfoYAML() const {
581591
return new yaml::RISCVMachineFunctionInfo();

llvm/lib/Target/RISCV/RISCVTargetMachine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class RISCVTargetMachine : public LLVMTargetMachine {
5959
PerFunctionMIParsingState &PFS,
6060
SMDiagnostic &Error,
6161
SMRange &SourceRange) const override;
62+
void registerPassBuilderCallbacks(PassBuilder &PB,
63+
bool PopulateClassToPassNames) override;
6264
};
6365
} // namespace llvm
6466

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
397397
bool shouldFoldTerminatingConditionAfterLSR() const {
398398
return true;
399399
}
400+
401+
/// Reports 4096 as the minimum page size for RISC-V.
// NOTE(review): presumably a size in bytes that feeds LoopIdiomTransform's
// page checks - confirm against the TargetTransformInfo interface.
std::optional<unsigned> getMinPageSize() const { return 4096; }
400402
};
401403

402404
} // end namespace llvm

llvm/lib/Transforms/Vectorize/LoopIdiomTransform.cpp

Lines changed: 157 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-transform-all", cl::Hidden,
5656
cl::init(false),
5757
cl::desc("Disable Loop Idiom Transform Pass."));
5858

59+
// Hidden command-line selector for the vectorization style. Accepts "masked"
// or "predicated" and is initialised to masked. LoopIdiomTransformPass::run
// only honours it when the flag was actually passed (getNumOccurrences() != 0),
// in which case it replaces the style the pass object was constructed with.
static cl::opt<LoopIdiomTransformStyle>
60+
LITVecStyle("loop-idiom-transform-style", cl::Hidden,
61+
cl::desc("The vectorization style for loop idiom transform."),
62+
cl::values(clEnumValN(LoopIdiomTransformStyle::Masked, "masked",
63+
"Use masked vector intrinsics"),
64+
clEnumValN(LoopIdiomTransformStyle::Predicated,
65+
"predicated", "Use VP intrinsics")),
66+
cl::init(LoopIdiomTransformStyle::Masked));
67+
5968
// Hidden flag, false by default. Per its description, the pass still runs when
// this is set but byte-compare loops are left unconverted.
static cl::opt<bool>
6069
DisableByteCmp("disable-loop-idiom-transform-bytecmp", cl::Hidden,
6170
cl::init(false),
6271
cl::desc("Proceed with Loop Idiom Transform Pass, but do "
6372
"not convert byte-compare loop(s)."));
6473

74+
// Hidden command-line override for the byte-compare vectorization factor,
// initialised to 16. LoopIdiomTransformPass::run only honours it when the flag
// was actually passed (getNumOccurrences() != 0), in which case it replaces
// the VF the pass object was constructed with.
static cl::opt<unsigned>
75+
ByteCmpVF("loop-idiom-transform-bytecmp-vf", cl::Hidden,
76+
cl::desc("The vectorization factor for byte-compare patterns."),
77+
cl::init(16));
78+
6579
// Hidden debugging flag, off by default. Per its description, it requests
// verification of the loops this pass generates.
static cl::opt<bool> VerifyLoops(
    "verify-loop-idiom-transform", cl::Hidden, cl::init(false),
    cl::desc("Verify loops generated by Loop Idiom Transform Pass."));
6882

6983
namespace {
7084
class LoopIdiomTransform {
85+
LoopIdiomTransformStyle VectorizeStyle;
86+
unsigned ByteCompareVF;
7187
Loop *CurLoop = nullptr;
7288
DominatorTree *DT;
7389
LoopInfo *LI;
@@ -82,10 +98,11 @@ class LoopIdiomTransform {
8298
BasicBlock *VectorLoopIncBlock = nullptr;
8399

84100
public:
85-
explicit LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
86-
const TargetTransformInfo *TTI,
87-
const DataLayout *DL)
88-
: DT(DT), LI(LI), TTI(TTI), DL(DL) {}
101+
LoopIdiomTransform(LoopIdiomTransformStyle S, unsigned VF, DominatorTree *DT,
102+
LoopInfo *LI, const TargetTransformInfo *TTI,
103+
const DataLayout *DL)
104+
: VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
105+
}
89106

90107
bool run(Loop *L);
91108

@@ -106,6 +123,10 @@ class LoopIdiomTransform {
106123
Value *createMaskedFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
107124
GetElementPtrInst *GEPB, Value *ExtStart,
108125
Value *ExtEnd);
126+
Value *createPredicatedFindMismatch(IRBuilder<> &Builder,
127+
GetElementPtrInst *GEPA,
128+
GetElementPtrInst *GEPB, Value *ExtStart,
129+
Value *ExtEnd);
109130

110131
void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
111132
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -123,7 +144,15 @@ PreservedAnalyses LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
123144

124145
const auto *DL = &L.getHeader()->getModule()->getDataLayout();
125146

126-
LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
147+
LoopIdiomTransformStyle VecStyle = VectorizeStyle;
148+
if (LITVecStyle.getNumOccurrences())
149+
VecStyle = LITVecStyle;
150+
151+
unsigned BCVF = ByteCompareVF;
152+
if (ByteCmpVF.getNumOccurrences())
153+
BCVF = ByteCmpVF;
154+
155+
LoopIdiomTransform LIT(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
127156
if (!LIT.run(&L))
128157
return PreservedAnalyses::all();
129158

@@ -357,14 +386,15 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
357386
// Therefore, we know that we can use a 64-bit induction variable that
358387
// starts from 0 -> ExtMaxLen and it will not overflow.
359388
ScalableVectorType *PredVTy =
360-
ScalableVectorType::get(Builder.getInt1Ty(), 16);
389+
ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
361390

362391
Value *InitialPred = Builder.CreateIntrinsic(
363392
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
364393

365394
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
366-
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
367-
/*HasNUW=*/true, /*HasNSW=*/true);
395+
VecLen =
396+
Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
397+
/*HasNUW=*/true, /*HasNSW=*/true);
368398

369399
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
370400
Builder.getInt1(false));
@@ -379,7 +409,8 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
379409
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
380410
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
381411
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
382-
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
412+
Type *VectorLoadType =
413+
ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
383414
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
384415

385416
Value *VectorLhsGep = Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi);
@@ -442,6 +473,112 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
442473
return Builder.CreateTrunc(VectorLoopRes64, ResType);
443474
}
444475

476+
/// Emit the vector loop of the byte-compare idiom using VP (vector-predicated)
/// intrinsics instead of masked intrinsics.
///
/// The loop walks the index range [\p ExtStart, \p ExtEnd). Each iteration
/// asks `llvm.experimental.get.vector.length` how many elements to process,
/// loads that many bytes from both inputs with `llvm.vp.load`, compares them
/// with `llvm.vp.icmp ne`, and locates the first differing lane with
/// `llvm.vp.cttz.elts`.
///
/// \param Builder  IR builder; the first instruction emitted is a branch into
///                 VectorLoopStartBlock, and that block's index PHI takes its
///                 initial value from VectorLoopPreheaderBlock. On return the
///                 builder is positioned in VectorLoopMismatchBlock.
/// \param GEPA     GEP of the first input; supplies the base pointer and the
///                 inbounds flag for the generated address computation.
/// \param GEPB     GEP of the second input; used like \p GEPA.
/// \param ExtStart 64-bit start index of the comparison.
/// \param ExtEnd   64-bit end index (exclusive) of the comparison.
/// \returns The i32 index of the first mismatching byte, computed in
///          VectorLoopMismatchBlock. When no mismatch is found the loop
///          branches to EndBlock instead.
Value *LoopIdiomTransform::createPredicatedFindMismatch(IRBuilder<> &Builder,
                                                        GetElementPtrInst *GEPA,
                                                        GetElementPtrInst *GEPB,
                                                        Value *ExtStart,
                                                        Value *ExtEnd) {
  Type *I64Type = Builder.getInt64Ty();
  Type *I32Type = Builder.getInt32Ty();
  Type *ResType = I32Type;
  Type *LoadType = Builder.getInt8Ty();
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // At this point we know two things must be true:
  //  1. Start <= End
  //  2. ExtMaxLen <= 4096 due to the page checks.
  // Therefore, we know that we can use a 64-bit induction variable that
  // starts from 0 -> ExtMaxLen and it will not overflow.
  auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
  Builder.Insert(JumpToVectorLoop);

  // Set up the first Vector loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
  Builder.SetInsertPoint(VectorLoopStartBlock);
  auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);

  // Calculate AVL by subtracting the vector loop index from the trip count
  Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
                                 /*HasNSW=*/true);

  auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
  auto *VF = ConstantInt::get(
      I32Type, VectorLoadType->getElementCount().getKnownMinValue());
  auto *IsScalable = ConstantInt::getBool(
      Builder.getContext(), VectorLoadType->getElementCount().isScalable());

  // VL is the i32 explicit vector length for this iteration; it is passed as
  // the EVL operand of every VP intrinsic below.
  Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
                                      {I64Type}, {AVL, VF, IsScalable});
  Value *GepOffset = VectorIndexPhi;

  Value *VectorLhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset);
  if (GEPA->isInBounds())
    cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds(true);
  // The VP operations are limited by VL alone, so the mask is all-true.
  VectorType *TrueMaskTy =
      VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
  Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
  Value *VectorLhsLoad = Builder.CreateIntrinsic(
      Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
      {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");

  Value *VectorRhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset);
  if (GEPB->isInBounds())
    cast<GetElementPtrInst>(VectorRhsGep)->setIsInBounds(true);
  // Mangle the RHS load on the RHS pointer's own type rather than borrowing
  // the LHS pointer type.
  Value *VectorRhsLoad = Builder.CreateIntrinsic(
      Intrinsic::vp_load, {VectorLoadType, VectorRhsGep->getType()},
      {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");

  // vp.icmp takes its predicate as a metadata string operand.
  StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE);
  auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr);
  Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS);
  Value *VectorMatchCmp = Builder.CreateIntrinsic(
      Intrinsic::vp_icmp, {VectorLhsLoad->getType()},
      {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr,
      "mismatch.cmp");

  // Ask for a well-defined result when no lane differs: with ZeroIsPoison set
  // to false, vp.cttz.elts returns the explicit vector length in that case.
  // Branching on a poison result would be undefined behaviour at the IR level,
  // so we must not depend on how a particular target lowers poison.
  Value *CTZ = Builder.CreateIntrinsic(
      Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
      {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask,
       VL});
  // A mismatch exists exactly when the first set lane lies below VL. CTZ and
  // VL are both i32, so they can be compared directly.
  Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL);
  auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
                                             VectorLoopIncBlock, MismatchFound);
  Builder.Insert(VectorEarlyExit);

  // Advance the index by the number of elements processed in this iteration.
  // We branch back to the start of the loop while the new index has not yet
  // reached ExtEnd, and leave for EndBlock otherwise.
  Builder.SetInsertPoint(VectorLoopIncBlock);
  Value *VL64 = Builder.CreateZExt(VL, I64Type);
  Value *NewVectorIndexPhi =
      Builder.CreateAdd(VectorIndexPhi, VL64, "",
                        /*HasNUW=*/true, /*HasNSW=*/true);
  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
  Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
  auto *VectorLoopBranchBack =
      BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
  Builder.Insert(VectorLoopBranchBack);

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
  Builder.SetInsertPoint(VectorLoopMismatchBlock);

  // Add LCSSA phis for CTZ and VectorIndexPhi.
  auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
  CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
  auto *VectorIndexLCSSAPhi =
      Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
  VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

  Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
  Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
  return Builder.CreateTrunc(VectorLoopRes64, ResType);
}
581+
445582
Value *LoopIdiomTransform::expandFindMismatch(
446583
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
447584
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -593,8 +730,17 @@ Value *LoopIdiomTransform::expandFindMismatch(
593730
// processed in each iteration, etc.
594731
Builder.SetInsertPoint(VectorLoopPreheaderBlock);
595732

596-
Value *VectorLoopRes =
597-
createMaskedFindMismatch(Builder, GEPA, GEPB, ExtStart, ExtEnd);
733+
Value *VectorLoopRes = nullptr;
734+
switch (VectorizeStyle) {
735+
case LoopIdiomTransformStyle::Masked:
736+
VectorLoopRes =
737+
createMaskedFindMismatch(Builder, GEPA, GEPB, ExtStart, ExtEnd);
738+
break;
739+
case LoopIdiomTransformStyle::Predicated:
740+
VectorLoopRes =
741+
createPredicatedFindMismatch(Builder, GEPA, GEPB, ExtStart, ExtEnd);
742+
break;
743+
}
598744

599745
Builder.Insert(BranchInst::Create(EndBlock));
600746

0 commit comments

Comments
 (0)