Skip to content

Commit 054c23d

Browse files
authored
X86: Improve cost model of fp16 conversion (#113195)
Improve cost-modeling for x86 __fp16 conversions so the SLPVectorizer transforms the patterns: - Override `X86TTIImpl::getStoreMinimumVF` to report a minimum VF of 4 (SSE register can hold 4xfloat converted/stored to 4xf16) this is necessary as fp16 stores are neither modeled as trunc-stores nor can we mark direct Xxfp16 stores as legal as we generally expand fp16 operations). - Add missing cost entries to `X86TTIImpl::getCastInstrCost` conversion from/to fp16. Note that conversion from f64 to f16 is not supported by an X86 instruction.
1 parent 75c1c26 commit 054c23d

File tree

3 files changed

+650
-0
lines changed

3 files changed

+650
-0
lines changed

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,7 +2296,10 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
22962296
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f32, { 1, 1, 1, 1 } },
22972297
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v16f32, { 3, 1, 1, 1 } },
22982298
{ ISD::FP_EXTEND, MVT::v16f64, MVT::v16f32, { 4, 1, 1, 1 } }, // 2*vcvtps2pd+vextractf64x4
2299+
{ ISD::FP_EXTEND, MVT::v16f32, MVT::v16f16, { 1, 1, 1, 1 } }, // vcvtph2ps
2300+
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
22992301
{ ISD::FP_ROUND, MVT::v8f32, MVT::v8f64, { 1, 1, 1, 1 } },
2302+
{ ISD::FP_ROUND, MVT::v16f16, MVT::v16f32, { 1, 1, 1, 1 } }, // vcvtps2ph
23002303

23012304
{ ISD::TRUNCATE, MVT::v2i1, MVT::v2i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
23022305
{ ISD::TRUNCATE, MVT::v4i1, MVT::v4i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
@@ -2973,6 +2976,17 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
29732976
{ ISD::TRUNCATE, MVT::v4i32, MVT::v2i64, { 1, 1, 1, 1 } }, // PSHUFD
29742977
};
29752978

2979+
static const TypeConversionCostKindTblEntry F16ConversionTbl[] = {
2980+
{ ISD::FP_ROUND, MVT::f16, MVT::f32, { 1, 1, 1, 1 } },
2981+
{ ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, { 1, 1, 1, 1 } },
2982+
{ ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, { 1, 1, 1, 1 } },
2983+
{ ISD::FP_EXTEND, MVT::f32, MVT::f16, { 1, 1, 1, 1 } },
2984+
{ ISD::FP_EXTEND, MVT::f64, MVT::f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2985+
{ ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, { 1, 1, 1, 1 } },
2986+
{ ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, { 1, 1, 1, 1 } },
2987+
{ ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2988+
};
2989+
29762990
// Attempt to map directly to (simple) MVT types to let us match custom entries.
29772991
EVT SrcTy = TLI->getValueType(DL, Src);
29782992
EVT DstTy = TLI->getValueType(DL, Dst);
@@ -3034,6 +3048,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
30343048
return *KindCost;
30353049
}
30363050

3051+
if (ST->hasF16C()) {
3052+
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3053+
SimpleDstTy, SimpleSrcTy))
3054+
if (auto KindCost = Entry->Cost[CostKind])
3055+
return *KindCost;
3056+
}
3057+
30373058
if (ST->hasSSE41()) {
30383059
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
30393060
SimpleDstTy, SimpleSrcTy))
@@ -3107,6 +3128,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
31073128
if (auto KindCost = Entry->Cost[CostKind])
31083129
return std::max(LTSrc.first, LTDest.first) * *KindCost;
31093130

3131+
if (ST->hasF16C()) {
3132+
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3133+
LTDest.second, LTSrc.second))
3134+
if (auto KindCost = Entry->Cost[CostKind])
3135+
return std::max(LTSrc.first, LTDest.first) * *KindCost;
3136+
}
3137+
31103138
if (ST->hasSSE41())
31113139
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
31123140
LTDest.second, LTSrc.second))
@@ -3146,6 +3174,11 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
31463174
TTI::CastContextHint::None, CostKind);
31473175
}
31483176

3177+
if (ISD == ISD::FP_ROUND && LTDest.second.getScalarType() == MVT::f16) {
3178+
// Conversion requires a libcall.
3179+
return InstructionCost::getInvalid();
3180+
}
3181+
31493182
// TODO: Allow non-throughput costs that aren't binary.
31503183
auto AdjustCost = [&CostKind](InstructionCost Cost,
31513184
InstructionCost N = 1) -> InstructionCost {
@@ -6923,6 +6956,14 @@ bool X86TTIImpl::isVectorShiftByScalarCheap(Type *Ty) const {
69236956
return true;
69246957
}
69256958

6959+
unsigned X86TTIImpl::getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
6960+
Type *ScalarValTy) const {
6961+
if (ST->hasF16C() && ScalarMemTy->isHalfTy()) {
6962+
return 4;
6963+
}
6964+
return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy);
6965+
}
6966+
69266967
bool X86TTIImpl::isProfitableToSinkOperands(Instruction *I,
69276968
SmallVectorImpl<Use *> &Ops) const {
69286969
using namespace llvm::PatternMatch;

llvm/lib/Target/X86/X86TargetTransformInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
302302

303303
bool isVectorShiftByScalarCheap(Type *Ty) const;
304304

305+
unsigned getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
306+
Type *ScalarValTy) const;
307+
305308
private:
306309
bool supportsGather() const;
307310
InstructionCost getGSVectorCost(unsigned Opcode, TTI::TargetCostKind CostKind,

0 commit comments

Comments
 (0)