@@ -2296,7 +2296,10 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
2296
2296
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f32, { 1, 1, 1, 1 } },
2297
2297
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v16f32, { 3, 1, 1, 1 } },
2298
2298
{ ISD::FP_EXTEND, MVT::v16f64, MVT::v16f32, { 4, 1, 1, 1 } }, // 2*vcvtps2pd+vextractf64x4
2299
+ { ISD::FP_EXTEND, MVT::v16f32, MVT::v16f16, { 1, 1, 1, 1 } }, // vcvtph2ps
2300
+ { ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2299
2301
{ ISD::FP_ROUND, MVT::v8f32, MVT::v8f64, { 1, 1, 1, 1 } },
2302
+ { ISD::FP_ROUND, MVT::v16f16, MVT::v16f32, { 1, 1, 1, 1 } }, // vcvtps2ph
2300
2303
2301
2304
{ ISD::TRUNCATE, MVT::v2i1, MVT::v2i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
2302
2305
{ ISD::TRUNCATE, MVT::v4i1, MVT::v4i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
@@ -2973,6 +2976,17 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
2973
2976
{ ISD::TRUNCATE, MVT::v4i32, MVT::v2i64, { 1, 1, 1, 1 } }, // PSHUFD
2974
2977
};
2975
2978
2979
+ static const TypeConversionCostKindTblEntry F16ConversionTbl[] = {
2980
+ { ISD::FP_ROUND, MVT::f16, MVT::f32, { 1, 1, 1, 1 } },
2981
+ { ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, { 1, 1, 1, 1 } },
2982
+ { ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, { 1, 1, 1, 1 } },
2983
+ { ISD::FP_EXTEND, MVT::f32, MVT::f16, { 1, 1, 1, 1 } },
2984
+ { ISD::FP_EXTEND, MVT::f64, MVT::f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2985
+ { ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, { 1, 1, 1, 1 } },
2986
+ { ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, { 1, 1, 1, 1 } },
2987
+ { ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2988
+ };
2989
+
2976
2990
// Attempt to map directly to (simple) MVT types to let us match custom entries.
2977
2991
EVT SrcTy = TLI->getValueType(DL, Src);
2978
2992
EVT DstTy = TLI->getValueType(DL, Dst);
@@ -3034,6 +3048,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
3034
3048
return *KindCost;
3035
3049
}
3036
3050
3051
+ if (ST->hasF16C()) {
3052
+ if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3053
+ SimpleDstTy, SimpleSrcTy))
3054
+ if (auto KindCost = Entry->Cost[CostKind])
3055
+ return *KindCost;
3056
+ }
3057
+
3037
3058
if (ST->hasSSE41()) {
3038
3059
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
3039
3060
SimpleDstTy, SimpleSrcTy))
@@ -3107,6 +3128,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
3107
3128
if (auto KindCost = Entry->Cost[CostKind])
3108
3129
return std::max(LTSrc.first, LTDest.first) * *KindCost;
3109
3130
3131
+ if (ST->hasF16C()) {
3132
+ if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3133
+ LTDest.second, LTSrc.second))
3134
+ if (auto KindCost = Entry->Cost[CostKind])
3135
+ return std::max(LTSrc.first, LTDest.first) * *KindCost;
3136
+ }
3137
+
3110
3138
if (ST->hasSSE41())
3111
3139
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
3112
3140
LTDest.second, LTSrc.second))
@@ -3146,6 +3174,11 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
3146
3174
TTI::CastContextHint::None, CostKind);
3147
3175
}
3148
3176
3177
+ if (ISD == ISD::FP_ROUND && LTDest.second.getScalarType() == MVT::f16) {
3178
+ // Conversion requires a libcall.
3179
+ return InstructionCost::getInvalid();
3180
+ }
3181
+
3149
3182
// TODO: Allow non-throughput costs that aren't binary.
3150
3183
auto AdjustCost = [&CostKind](InstructionCost Cost,
3151
3184
InstructionCost N = 1) -> InstructionCost {
@@ -6923,6 +6956,14 @@ bool X86TTIImpl::isVectorShiftByScalarCheap(Type *Ty) const {
6923
6956
return true;
6924
6957
}
6925
6958
6959
+ unsigned X86TTIImpl::getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
6960
+ Type *ScalarValTy) const {
6961
+ if (ST->hasF16C() && ScalarMemTy->isHalfTy()) {
6962
+ return 4;
6963
+ }
6964
+ return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy);
6965
+ }
6966
+
6926
6967
bool X86TTIImpl::isProfitableToSinkOperands(Instruction *I,
6927
6968
SmallVectorImpl<Use *> &Ops) const {
6928
6969
using namespace llvm::PatternMatch;
0 commit comments