Commit 401f6e1

[LoopVectorize] Add cost of generating tail-folding mask to the loop

At the moment, if we decide to enable tail-folding we do not include the cost
of generating the mask per VF. This can mean we make some poor choices of VF,
which is definitely true for SVE-enabled AArch64 targets, where mask
generation for fixed-width vectors is more expensive than for scalable
vectors.

I've added a VPInstruction::computeCost function to return the costs of the
ActiveLaneMask and ExplicitVectorLength operations. Unfortunately, in order
to prevent asserts firing I've also had to duplicate the same code in the
legacy cost model to make sure the chosen VFs match up. I've wrapped this up
in an #ifndef NDEBUG for now.

New tests added:

  Transforms/LoopVectorize/AArch64/sve-tail-folding-cost.ll
  Transforms/LoopVectorize/RISCV/tail-folding-cost.ll
1 parent f6b1b91 commit 401f6e1

11 files changed: +475 −353 lines

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 25 additions & 0 deletions
@@ -5610,6 +5610,31 @@ InstructionCost LoopVectorizationCostModel::expectedCost(ElementCount VF) {
     Cost += BlockCost;
   }
 
+#ifndef NDEBUG
+  // TODO: We're effectively having to duplicate the code from
+  // VPInstruction::computeCost, which is ugly. This isn't meant to be a fully
+  // accurate representation of the cost of tail-folding - it exists purely to
+  // stop asserts firing when the legacy cost doesn't match the VPlan cost.
+  if (!VF.isScalar() && foldTailByMasking()) {
+    TailFoldingStyle Style = getTailFoldingStyle();
+    LLVMContext &Context = TheLoop->getHeader()->getContext();
+    Type *I1Ty = IntegerType::getInt1Ty(Context);
+    Type *IndTy = Legal->getWidestInductionType();
+    TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+    if (Style == TailFoldingStyle::DataWithEVL) {
+      Type *I32Ty = IntegerType::getInt32Ty(Context);
+      IntrinsicCostAttributes Attrs(Intrinsic::experimental_get_vector_length,
+                                    I32Ty, {IndTy, I32Ty, I1Ty});
+      Cost += TTI.getIntrinsicInstrCost(Attrs, CostKind);
+    } else if (useActiveLaneMask(Style)) {
+      VectorType *RetTy = VectorType::get(I1Ty, VF);
+      IntrinsicCostAttributes Attrs(Intrinsic::get_active_lane_mask, RetTy,
+                                    {IndTy, IndTy});
+      Cost += TTI.getIntrinsicInstrCost(Attrs, CostKind);
+    }
+  }
+#endif
+
   return Cost;
 }
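Note: as the TODO above says, this block deliberately duplicates
VPInstruction::computeCost. It is compiled only in assertion-enabled builds
and exists solely so the legacy cost model and the VPlan cost model agree on
the chosen VF; it can be dropped once the legacy cost model is retired.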

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 16 additions & 0 deletions
@@ -801,6 +801,22 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
                                   cast<VectorType>(VectorTy), Mask,
                                   Ctx.CostKind, VF.getKnownMinValue() - 1);
   }
+  case VPInstruction::ActiveLaneMask: {
+    Type *Arg0Ty = Ctx.Types.inferScalarType(getOperand(0));
+    Type *Arg1Ty = Ctx.Types.inferScalarType(getOperand(1));
+    Type *RetTy = toVectorTy(Type::getInt1Ty(Ctx.LLVMCtx), VF);
+    IntrinsicCostAttributes Attrs(Intrinsic::get_active_lane_mask, RetTy,
+                                  {Arg0Ty, Arg1Ty});
+    return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
+  }
+  case VPInstruction::ExplicitVectorLength: {
+    Type *Arg0Ty = Ctx.Types.inferScalarType(getOperand(0));
+    Type *I32Ty = Type::getInt32Ty(Ctx.LLVMCtx);
+    Type *I1Ty = Type::getInt1Ty(Ctx.LLVMCtx);
+    IntrinsicCostAttributes Attrs(Intrinsic::experimental_get_vector_length,
+                                  I32Ty, {Arg0Ty, I32Ty, I1Ty});
+    return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
+  }
   default:
     // TODO: Compute cost other VPInstructions once the legacy cost model has
     // been retired.
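For reference, the type lists passed to IntrinsicCostAttributes mirror the
signatures of the intrinsics being costed: @llvm.get.active.lane.mask takes a
base index and a trip count (here both inferred from the operands) and
returns a <VF x i1> mask, while @llvm.experimental.get.vector.length takes
the remaining trip count plus constant i32 VF and i1 'scalable' arguments and
returns the number of lanes to process as an i32.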

llvm/test/Transforms/LoopVectorize/AArch64/conditional-branches-cost.ll

Lines changed: 35 additions & 134 deletions
(Large diff not shown.)

llvm/test/Transforms/LoopVectorize/AArch64/induction-costs-sve.ll

Lines changed: 20 additions & 20 deletions
@@ -99,49 +99,49 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
 ; PRED-NEXT:    br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_MEMCHECK:.*]]
 ; PRED:       [[VECTOR_MEMCHECK]]:
 ; PRED-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; PRED-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 8
+; PRED-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 16
 ; PRED-NEXT:    [[TMP3:%.*]] = sub i64 [[DST1]], [[SRC2]]
 ; PRED-NEXT:    [[DIFF_CHECK:%.*]] = icmp ult i64 [[TMP3]], [[TMP2]]
 ; PRED-NEXT:    br i1 [[DIFF_CHECK]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]]
 ; PRED:       [[VECTOR_PH]]:
 ; PRED-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; PRED-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 8
+; PRED-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
 ; PRED-NEXT:    [[TMP8:%.*]] = sub i64 [[TMP5]], 1
 ; PRED-NEXT:    [[N_RND_UP:%.*]] = add i64 [[TMP0]], [[TMP8]]
 ; PRED-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[N_RND_UP]], [[TMP5]]
 ; PRED-NEXT:    [[N_VEC:%.*]] = sub i64 [[N_RND_UP]], [[N_MOD_VF]]
 ; PRED-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
-; PRED-NEXT:    [[TMP10:%.*]] = mul i64 [[TMP9]], 8
-; PRED-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[X]], i64 0
-; PRED-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+; PRED-NEXT:    [[TMP10:%.*]] = mul i64 [[TMP9]], 16
+; PRED-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 16 x i32> poison, i32 [[X]], i64 0
+; PRED-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 16 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 16 x i32> poison, <vscale x 16 x i32> zeroinitializer
 ; PRED-NEXT:    [[TMP11:%.*]] = call i64 @llvm.vscale.i64()
-; PRED-NEXT:    [[TMP12:%.*]] = mul i64 [[TMP11]], 8
+; PRED-NEXT:    [[TMP12:%.*]] = mul i64 [[TMP11]], 16
 ; PRED-NEXT:    [[TMP13:%.*]] = sub i64 [[TMP0]], [[TMP12]]
 ; PRED-NEXT:    [[TMP14:%.*]] = icmp ugt i64 [[TMP0]], [[TMP12]]
 ; PRED-NEXT:    [[TMP15:%.*]] = select i1 [[TMP14]], i64 [[TMP13]], i64 0
-; PRED-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[TMP0]])
-; PRED-NEXT:    [[TMP16:%.*]] = trunc <vscale x 8 x i32> [[BROADCAST_SPLAT]] to <vscale x 8 x i16>
+; PRED-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 [[TMP0]])
+; PRED-NEXT:    [[TMP16:%.*]] = trunc <vscale x 16 x i32> [[BROADCAST_SPLAT]] to <vscale x 16 x i16>
 ; PRED-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; PRED:       [[VECTOR_BODY]]:
 ; PRED-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; PRED-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], %[[VECTOR_PH]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; PRED-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 16 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], %[[VECTOR_PH]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; PRED-NEXT:    [[TMP17:%.*]] = add i64 [[INDEX]], 0
 ; PRED-NEXT:    [[TMP18:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[TMP17]]
 ; PRED-NEXT:    [[TMP19:%.*]] = getelementptr i8, ptr [[TMP18]], i32 0
-; PRED-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8.p0(ptr [[TMP19]], i32 1, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i8> poison)
-; PRED-NEXT:    [[TMP20:%.*]] = zext <vscale x 8 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 8 x i16>
-; PRED-NEXT:    [[TMP21:%.*]] = mul <vscale x 8 x i16> [[TMP20]], [[TMP16]]
-; PRED-NEXT:    [[TMP22:%.*]] = zext <vscale x 8 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 8 x i16>
-; PRED-NEXT:    [[TMP23:%.*]] = or <vscale x 8 x i16> [[TMP21]], [[TMP22]]
-; PRED-NEXT:    [[TMP24:%.*]] = lshr <vscale x 8 x i16> [[TMP23]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
-; PRED-NEXT:    [[TMP25:%.*]] = trunc <vscale x 8 x i16> [[TMP24]] to <vscale x 8 x i8>
+; PRED-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP19]], i32 1, <vscale x 16 x i1> [[ACTIVE_LANE_MASK]], <vscale x 16 x i8> poison)
+; PRED-NEXT:    [[TMP24:%.*]] = zext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i16>
+; PRED-NEXT:    [[TMP25:%.*]] = mul <vscale x 16 x i16> [[TMP24]], [[TMP16]]
+; PRED-NEXT:    [[TMP20:%.*]] = zext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i16>
+; PRED-NEXT:    [[TMP21:%.*]] = or <vscale x 16 x i16> [[TMP25]], [[TMP20]]
+; PRED-NEXT:    [[TMP22:%.*]] = lshr <vscale x 16 x i16> [[TMP21]], trunc (<vscale x 16 x i32> splat (i32 1) to <vscale x 16 x i16>)
+; PRED-NEXT:    [[TMP23:%.*]] = trunc <vscale x 16 x i16> [[TMP22]] to <vscale x 16 x i8>
 ; PRED-NEXT:    [[TMP26:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP17]]
 ; PRED-NEXT:    [[TMP27:%.*]] = getelementptr i8, ptr [[TMP26]], i32 0
-; PRED-NEXT:    call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> [[TMP25]], ptr [[TMP27]], i32 1, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]])
+; PRED-NEXT:    call void @llvm.masked.store.nxv16i8.p0(<vscale x 16 x i8> [[TMP23]], ptr [[TMP27]], i32 1, <vscale x 16 x i1> [[ACTIVE_LANE_MASK]])
 ; PRED-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP10]]
-; PRED-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX]], i64 [[TMP15]])
-; PRED-NEXT:    [[TMP28:%.*]] = xor <vscale x 8 x i1> [[ACTIVE_LANE_MASK_NEXT]], splat (i1 true)
-; PRED-NEXT:    [[TMP29:%.*]] = extractelement <vscale x 8 x i1> [[TMP28]], i32 0
+; PRED-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 [[INDEX]], i64 [[TMP15]])
+; PRED-NEXT:    [[TMP28:%.*]] = xor <vscale x 16 x i1> [[ACTIVE_LANE_MASK_NEXT]], splat (i1 true)
+; PRED-NEXT:    [[TMP29:%.*]] = extractelement <vscale x 16 x i1> [[TMP28]], i32 0
 ; PRED-NEXT:    br i1 [[TMP29]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; PRED:       [[MIDDLE_BLOCK]]:
 ; PRED-NEXT:    br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]]
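The effect of the new mask cost is visible in the updated CHECK lines above:
for this SVE-predicated loop the planner now picks VF = vscale x 16 instead
of vscale x 8. A plausible reading is that the wider VF amortises the
per-iteration mask-generation cost over more lanes.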
