Skip to content

Commit 9745c13

Browse files
authored
[X86][BF16] Improve float -> bfloat lowering under AVX512BF16 and AVXNECONVERT (#78042)
1 parent 46a395d commit 9745c13

File tree

3 files changed

+388
-487
lines changed

3 files changed

+388
-487
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21523,9 +21523,19 @@ static SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) {
2152321523
SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
2152421524
SelectionDAG &DAG) const {
2152521525
SDLoc DL(Op);
21526+
21527+
MVT SVT = Op.getOperand(0).getSimpleValueType();
21528+
if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
21529+
SDValue Res;
21530+
Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
21531+
Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
21532+
Res = DAG.getBitcast(MVT::v8i16, Res);
21533+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i16, Res,
21534+
DAG.getIntPtrConstant(0, DL));
21535+
}
21536+
2152621537
MakeLibCallOptions CallOptions;
21527-
RTLIB::Libcall LC =
21528-
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
21538+
RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, MVT::bf16);
2152921539
SDValue Res =
2153021540
makeLibCall(DAG, LC, MVT::f16, Op.getOperand(0), CallOptions, DL).first;
2153121541
return DAG.getBitcast(MVT::i16, Res);

llvm/lib/Target/X86/X86InstrSSE.td

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8331,6 +8331,10 @@ let Predicates = [HasAVXNECONVERT] in {
83318331
f256mem>, T8;
83328332
defm VCVTNEPS2BF16 : VCVTNEPS2BF16_BASE, VEX, T8, XS, ExplicitVEXPrefix;
83338333

8334+
// Register form: an X86cvtneps2bf16 node that takes a v4f32 register operand
// and produces v8bf16 is selected to VCVTNEPS2BF16rr.
// NOTE(review): the matched operand is written as VR128X while the emitted
// operand is VR128; the neighbouring X86vfpround pattern uses VR256 on both
// sides -- confirm the VR128X/VR128 mismatch is intentional.
def : Pat<(v8bf16 (X86cvtneps2bf16 (v4f32 VR128X:$src))),
8335+
(VCVTNEPS2BF16rr VR128:$src)>;
8336+
// Memory form: when the v4f32 operand of X86cvtneps2bf16 is a load
// (loadv4f32), select VCVTNEPS2BF16rm and pass the address as its operand.
def : Pat<(v8bf16 (X86cvtneps2bf16 (loadv4f32 addr:$src))),
8337+
(VCVTNEPS2BF16rm addr:$src)>;
83348338
// An X86vfpround from a v8f32 register operand to v8bf16 is selected to the
// Y-suffixed register form, VCVTNEPS2BF16Yrr, on a VR256 operand.
def : Pat<(v8bf16 (X86vfpround (v8f32 VR256:$src))),
83358339
(VCVTNEPS2BF16Yrr VR256:$src)>;
83368340
def : Pat<(v8bf16 (X86vfpround (loadv8f32 addr:$src))),

0 commit comments

Comments
 (0)