Skip to content

Commit df8e0d0

Browse files
author
Rin
authored
[AArch64][LoopVectorize] Use upper bound trip count instead of the constant TC when choosing max VF (#67697)
This patch is based off of #67543. We are currently using the exact trip count to make decisions regarding the maximum VF. We can instead use the upper bound TC, which will be the same as the constant trip count when that is known.
1 parent f9bd62f commit df8e0d0

File tree

2 files changed

+55
-23
lines changed

2 files changed

+55
-23
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,14 +1695,14 @@ class LoopVectorizationCostModel {
16951695
/// elements is a power-of-2 larger than zero. If scalable vectorization is
16961696
/// disabled or unsupported, then the scalable part will be equal to
16971697
/// ElementCount::getScalable(0).
1698-
FixedScalableVFPair computeFeasibleMaxVF(unsigned ConstTripCount,
1698+
FixedScalableVFPair computeFeasibleMaxVF(unsigned MaxTripCount,
16991699
ElementCount UserVF,
17001700
bool FoldTailByMasking);
17011701

17021702
/// \return the maximized element count based on the targets vector
17031703
/// registers and the loop trip-count, but limited to a maximum safe VF.
17041704
/// This is a helper function of computeFeasibleMaxVF.
1705-
ElementCount getMaximizedVFForTarget(unsigned ConstTripCount,
1705+
ElementCount getMaximizedVFForTarget(unsigned MaxTripCount,
17061706
unsigned SmallestType,
17071707
unsigned WidestType,
17081708
ElementCount MaxSafeVF,
@@ -4809,7 +4809,7 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
48094809
}
48104810

48114811
FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
4812-
unsigned ConstTripCount, ElementCount UserVF, bool FoldTailByMasking) {
4812+
unsigned MaxTripCount, ElementCount UserVF, bool FoldTailByMasking) {
48134813
MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);
48144814
unsigned SmallestType, WidestType;
48154815
std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes();
@@ -4897,12 +4897,12 @@ FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
48974897
FixedScalableVFPair Result(ElementCount::getFixed(1),
48984898
ElementCount::getScalable(0));
48994899
if (auto MaxVF =
4900-
getMaximizedVFForTarget(ConstTripCount, SmallestType, WidestType,
4900+
getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType,
49014901
MaxSafeFixedVF, FoldTailByMasking))
49024902
Result.FixedVF = MaxVF;
49034903

49044904
if (auto MaxVF =
4905-
getMaximizedVFForTarget(ConstTripCount, SmallestType, WidestType,
4905+
getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType,
49064906
MaxSafeScalableVF, FoldTailByMasking))
49074907
if (MaxVF.isScalable()) {
49084908
Result.ScalableVF = MaxVF;
@@ -4926,6 +4926,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
49264926
}
49274927

49284928
unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
4929+
unsigned MaxTC = PSE.getSE()->getSmallConstantMaxTripCount(TheLoop);
49294930
LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
49304931
if (TC == 1) {
49314932
reportVectorizationFailure("Single iteration (non) loop",
@@ -4936,7 +4937,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
49364937

49374938
switch (ScalarEpilogueStatus) {
49384939
case CM_ScalarEpilogueAllowed:
4939-
return computeFeasibleMaxVF(TC, UserVF, false);
4940+
return computeFeasibleMaxVF(MaxTC, UserVF, false);
49404941
case CM_ScalarEpilogueNotAllowedUsePredicate:
49414942
[[fallthrough]];
49424943
case CM_ScalarEpilogueNotNeededUsePredicate:
@@ -4974,7 +4975,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
49744975
LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking: vectorize with a "
49754976
"scalar epilogue instead.\n");
49764977
ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
4977-
return computeFeasibleMaxVF(TC, UserVF, false);
4978+
return computeFeasibleMaxVF(MaxTC, UserVF, false);
49784979
}
49794980
return FixedScalableVFPair::getNone();
49804981
}
@@ -4991,7 +4992,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
49914992
InterleaveInfo.invalidateGroupsRequiringScalarEpilogue();
49924993
}
49934994

4994-
FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(TC, UserVF, true);
4995+
FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(MaxTC, UserVF, true);
49954996

49964997
// Avoid tail folding if the trip count is known to be a multiple of any VF
49974998
// we choose.
@@ -5067,7 +5068,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
50675068
}
50685069

50695070
ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
5070-
unsigned ConstTripCount, unsigned SmallestType, unsigned WidestType,
5071+
unsigned MaxTripCount, unsigned SmallestType, unsigned WidestType,
50715072
ElementCount MaxSafeVF, bool FoldTailByMasking) {
50725073
bool ComputeScalableMaxVF = MaxSafeVF.isScalable();
50735074
const TypeSize WidestRegister = TTI.getRegisterBitWidth(
@@ -5106,24 +5107,24 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
51065107
}
51075108

51085109
// When a scalar epilogue is required, at least one iteration of the scalar
5109-
// loop has to execute. Adjust ConstTripCount accordingly to avoid picking a
5110+
// loop has to execute. Adjust MaxTripCount accordingly to avoid picking a
51105111
// max VF that results in a dead vector loop.
5111-
if (ConstTripCount > 0 && requiresScalarEpilogue(true))
5112-
ConstTripCount -= 1;
5113-
5114-
if (ConstTripCount && ConstTripCount <= WidestRegisterMinEC &&
5115-
(!FoldTailByMasking || isPowerOf2_32(ConstTripCount))) {
5116-
// If loop trip count (TC) is known at compile time there is no point in
5117-
// choosing VF greater than TC (as done in the loop below). Select maximum
5118-
// power of two which doesn't exceed TC.
5119-
// If MaxVectorElementCount is scalable, we only fall back on a fixed VF
5120-
// when the TC is less than or equal to the known number of lanes.
5121-
auto ClampedConstTripCount = llvm::bit_floor(ConstTripCount);
5112+
if (MaxTripCount > 0 && requiresScalarEpilogue(true))
5113+
MaxTripCount -= 1;
5114+
5115+
if (MaxTripCount && MaxTripCount <= WidestRegisterMinEC &&
5116+
(!FoldTailByMasking || isPowerOf2_32(MaxTripCount))) {
5117+
// If upper bound loop trip count (TC) is known at compile time there is no
5118+
// point in choosing VF greater than TC (as done in the loop below). Select
5119+
// maximum power of two which doesn't exceed TC. If MaxVectorElementCount is
5120+
// scalable, we only fall back on a fixed VF when the TC is less than or
5121+
// equal to the known number of lanes.
5122+
auto ClampedUpperTripCount = llvm::bit_floor(MaxTripCount);
51225123
LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not "
51235124
"exceeding the constant trip count: "
5124-
<< ClampedConstTripCount << "\n");
5125+
<< ClampedUpperTripCount << "\n");
51255126
return ElementCount::get(
5126-
ClampedConstTripCount,
5127+
ClampedUpperTripCount,
51275128
FoldTailByMasking ? MaxVectorElementCount.isScalable() : false);
51285129
}
51295130

llvm/test/Transforms/LoopVectorize/AArch64/clamped-trip-count.ll

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,34 @@ for.body: ; preds = %entry, %for.body
2121
for.cond.cleanup: ; preds = %for.body
2222
ret void
2323
}
24+
25+
define void @clamped_tc_max_8(ptr nocapture %dst, i32 %n, i64 %val){
26+
; CHECK-LABEL: define void @clamped_tc_max_8(
27+
; CHECK: call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> {{.*}}, ptr {{.*}}, i32 1, <vscale x 8 x i1> {{.*}})
28+
29+
entry:
30+
%rem = and i32 %n, 63
31+
%cmp8.not = icmp eq i32 %rem, 0
32+
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body.preheader
33+
34+
for.body.preheader: ; preds = %entry
35+
%add = add nuw nsw i32 %rem, 7
36+
%shr = lshr i32 %add, 3
37+
%wide.trip.count = zext i32 %shr to i64
38+
br label %for.body
39+
40+
for.body: ; preds = %for.body.preheader, %for.body
41+
%indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body ]
42+
%p_out_tail.09 = phi ptr [ %dst, %for.body.preheader ], [ %incdec.ptr, %for.body ]
43+
%0 = shl nuw nsw i64 %indvars.iv, 3
44+
%shr3 = lshr i64 %val, %0
45+
%conv4 = trunc i64 %shr3 to i8
46+
store i8 %conv4, ptr %p_out_tail.09, align 1
47+
%incdec.ptr = getelementptr inbounds i8, ptr %p_out_tail.09, i64 1
48+
%indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
49+
%exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
50+
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
51+
52+
for.cond.cleanup: ; preds = %for.body
53+
ret void
54+
}

0 commit comments

Comments
 (0)