Skip to content

Commit eaf48dd

Browse files
committed
[VPlan] Replace BranchOnCount with BranchOnCond if TC <= UF * VF.
Try to simplify BranchOnCount to `BranchOnCond true` if TC <= UF * VF.

This is an alternative to D121899; it simplifies the VPlan directly instead of doing so late in code-gen.

The potential benefit of doing this in VPlan is that it may help cost-modeling in the future. The reason this is done in prepareToExecute at the moment is that a single plan may be used for multiple VFs/UFs.

There are further simplifications that can be applied as follow-ups:
1. Replace inductions with constants.
2. Replace vector region with regular block.

Fixes #55354.

Depends on D126679.

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D126680
1 parent c8db406 commit eaf48dd

File tree

6 files changed

+34
-23
lines changed

6 files changed

+34
-23
lines changed

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ void VPInstruction::generateInstruction(VPTransformState &State,
767767
case VPInstruction::BranchOnCond: {
768768
if (Part != 0)
769769
break;
770+
770771
Value *Cond = State.get(getOperand(0), VPIteration(Part, 0));
771772
VPRegionBlock *ParentRegion = getParent()->getParent();
772773
VPBasicBlock *Header = ParentRegion->getEntryBasicBlock();
@@ -898,6 +899,28 @@ void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) {
898899
void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
899900
Value *CanonicalIVStartValue,
900901
VPTransformState &State) {
902+
903+
VPBasicBlock *ExitingVPBB = getVectorLoopRegion()->getExitingBasicBlock();
904+
auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back());
905+
// Try to simplify BranchOnCount to 'BranchOnCond true' if TC <= VF * UF when
906+
// preparing to execute the plan for the main vector loop.
907+
if (!CanonicalIVStartValue && Term &&
908+
Term->getOpcode() == VPInstruction::BranchOnCount &&
909+
isa<ConstantInt>(TripCountV)) {
910+
ConstantInt *C = cast<ConstantInt>(TripCountV);
911+
uint64_t TCVal = C->getZExtValue();
912+
if (TCVal && TCVal <= State.VF.getKnownMinValue() * State.UF) {
913+
auto *BOC =
914+
new VPInstruction(VPInstruction::BranchOnCond,
915+
{getOrAddExternalDef(State.Builder.getTrue())});
916+
Term->eraseFromParent();
917+
ExitingVPBB->appendRecipe(BOC);
918+
// TODO: Further simplifications are possible
919+
// 1. Replace inductions with constants.
920+
// 2. Replace vector loop region with VPBasicBlock.
921+
}
922+
}
923+
901924
// Check if the trip count is needed, and if so build it.
902925
if (TripCount && TripCount->getNumUsers()) {
903926
for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part)

llvm/test/Transforms/LoopVectorize/AArch64/sve-low-trip-count.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ define void @trip5_i8(i8* noalias nocapture noundef %dst, i8* noalias nocapture
4747
; CHECK: [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
4848
; CHECK-NEXT: [[VF:%.*]] = mul i64 [[VSCALE]], 16
4949
; CHECK-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], [[VF]]
50-
; CHECK-NEXT: [[COND:%.*]] = icmp eq i64 [[INDEX_NEXT]], {{%.*}}
51-
; CHECK-NEXT: br i1 [[COND]], label %middle.block, label %vector.body
50+
; CHECK-NEXT: br i1 true, label %middle.block, label %vector.body
5251
;
5352
entry:
5453
br label %for.body

llvm/test/Transforms/LoopVectorize/X86/constant-fold.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ define void @f1() {
2626
; CHECK-NEXT: [[TMP4:%.*]] = bitcast i16** [[TMP3]] to <2 x i16*>*
2727
; CHECK-NEXT: store <2 x i16*> <i16* getelementptr inbounds ([1 x %rec8], [1 x %rec8]* @a, i32 0, i32 0, i32 0), i16* getelementptr inbounds ([1 x %rec8], [1 x %rec8]* @a, i32 0, i32 0, i32 0)>, <2 x i16*>* [[TMP4]], align 8
2828
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 2
29-
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], 2
30-
; CHECK-NEXT: br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP0:!llvm.loop !.*]]
29+
; CHECK-NEXT: br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP0:!llvm.loop !.*]]
3130
; CHECK: middle.block:
3231
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 2, 2
3332
; CHECK-NEXT: br i1 [[CMP_N]], label [[BB3:%.*]], label [[SCALAR_PH]]

llvm/test/Transforms/LoopVectorize/X86/outer_loop_test1_no_explicit_vect_width.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,9 @@
7171
; AVX: br i1 %[[InnerCond]], label %[[ForInc]], label %[[InnerLoop]]
7272

7373
; AVX: [[ForInc]]:
74-
; AVX: %[[IndNext]] = add nuw i64 %[[Ind]], 8
7574
; AVX: %[[VecIndNext]] = add <8 x i64> %[[VecInd]], <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
76-
; AVX: %[[Cmp:.*]] = icmp eq i64 %[[IndNext]], 8
77-
; AVX: br i1 %[[Cmp]], label %middle.block, label %vector.body
75+
; AVX: %[[IndNext]] = add nuw i64 %[[Ind]], 8
76+
; AVX: br i1 true, label %middle.block, label %vector.body
7877

7978
@arr2 = external global [8 x i32], align 16
8079
@arr = external global [8 x [8 x i32]], align 16

llvm/test/Transforms/LoopVectorize/X86/pr34438.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ define void @small_tc(float* noalias nocapture %A, float* noalias nocapture read
3030
; CHECK-NEXT: [[TMP8:%.*]] = bitcast float* [[TMP5]] to <8 x float>*
3131
; CHECK-NEXT: store <8 x float> [[TMP7]], <8 x float>* [[TMP8]], align 4, !llvm.access.group !0
3232
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
33-
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 8
34-
; CHECK-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP1:!llvm.loop !.*]]
33+
; CHECK-NEXT: br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP1:!llvm.loop !.*]]
3534
; CHECK: middle.block:
3635
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 8, 8
3736
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]

llvm/test/Transforms/LoopVectorize/X86/pr42674.ll

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,18 @@
99
define zeroext i8 @sum() {
1010
; CHECK-LABEL: @sum(
1111
; CHECK-NEXT: iter.check:
12-
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
13-
; CHECK: vector.body:
14-
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
15-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <64 x i8> [ zeroinitializer, [[ENTRY]] ], [ [[TMP4:%.*]], [[VECTOR_BODY]] ]
16-
; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <64 x i8> [ zeroinitializer, [[ENTRY]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
17-
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [128 x i8], [128 x i8]* @bytes, i64 0, i64 [[INDEX]]
12+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [128 x i8], [128 x i8]* @bytes, i64 0, i64 0
1813
; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <64 x i8>*
1914
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <64 x i8>, <64 x i8>* [[TMP1]], align 16
2015
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[TMP0]], i64 64
2116
; CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to <64 x i8>*
2217
; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <64 x i8>, <64 x i8>* [[TMP3]], align 16
23-
; CHECK-NEXT: [[TMP4]] = add <64 x i8> [[WIDE_LOAD]], [[VEC_PHI]]
24-
; CHECK-NEXT: [[TMP5]] = add <64 x i8> [[WIDE_LOAD2]], [[VEC_PHI1]]
25-
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 128
26-
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX]], 0
27-
; CHECK-NEXT: br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop !0
28-
; CHECK: middle.block:
18+
; CHECK-NEXT: [[TMP4:%.*]] = add <64 x i8> [[WIDE_LOAD]], zeroinitializer
19+
; CHECK-NEXT: [[TMP5:%.*]] = add <64 x i8> [[WIDE_LOAD2]], zeroinitializer
20+
; CHECK-NEXT: [[INDEX_NEXT:%.*]] = add nuw i64 0, 128
2921
; CHECK-NEXT: [[BIN_RDX:%.*]] = add <64 x i8> [[TMP5]], [[TMP4]]
30-
; CHECK-NEXT: [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v64i8(<64 x i8> [[BIN_RDX]])
31-
; CHECK-NEXT: ret i8 [[TMP7]]
22+
; CHECK-NEXT: [[TMP6:%.*]] = call i8 @llvm.vector.reduce.add.v64i8(<64 x i8> [[BIN_RDX]])
23+
; CHECK-NEXT: ret i8 [[TMP6]]
3224
;
3325
entry:
3426
br label %for.body

0 commit comments

Comments
 (0)