[X86] Improve __bf16 code generation #134859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: wants to merge 3 commits into main
11 changes: 8 additions & 3 deletions llvm/lib/Target/X86/X86CallingConv.td
@@ -364,6 +364,7 @@ def RetCC_X86_32_VectorCall : CallingConv<[
def RetCC_X86_64_C : CallingConv<[
// The X86-64 calling convention always returns FP values in XMM0.
CCIfType<[f16], CCAssignToReg<[XMM0, XMM1]>>,
CCIfType<[bf16], CCAssignToReg<[XMM0, XMM1]>>,
CCIfType<[f32], CCAssignToReg<[XMM0, XMM1]>>,
CCIfType<[f64], CCAssignToReg<[XMM0, XMM1]>>,
CCIfType<[f128], CCAssignToReg<[XMM0, XMM1]>>,
@@ -569,6 +570,10 @@ def CC_X86_64_C : CallingConv<[
CCIfSubtarget<"hasSSE1()",
CCAssignToReg<[XMM0, XMM1, XMM2, XMM3, XMM4, XMM5, XMM6, XMM7]>>>,

// The first 8 bf16 arguments are passed in XMM registers when AVX-512 is available.
CCIfType<[bf16], CCIfSubtarget<"hasAVX512()",
CCAssignToReg<[XMM0, XMM1, XMM2, XMM3, XMM4, XMM5, XMM6, XMM7]>>>,

// The first 8 256-bit vector arguments are passed in YMM registers, unless
// this is a vararg function.
// FIXME: This isn't precisely correct; the x86-64 ABI document says that
@@ -586,7 +591,7 @@ def CC_X86_64_C : CallingConv<[

// Integer/FP values get stored in stack slots that are 8 bytes in size and
// 8-byte aligned if there are no more registers to hold them.
CCIfType<[i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>,
CCIfType<[i32, i64, bf16, f16, f32, f64], CCAssignToStack<8, 8>>,

// Long doubles get stack slots whose size and alignment depends on the
// subtarget.
@@ -649,7 +654,7 @@ def CC_X86_Win64_C : CallingConv<[
CCIfType<[f64], CCIfNotSubtarget<"hasSSE1()", CCBitConvertToType<i64>>>,

// The first 4 FP/Vector arguments are passed in XMM registers.
CCIfType<[f16, f32, f64],
CCIfType<[bf16, f16, f32, f64],
CCAssignToRegWithShadow<[XMM0, XMM1, XMM2, XMM3],
[RCX , RDX , R8 , R9 ]>>,

@@ -672,7 +677,7 @@ def CC_X86_Win64_C : CallingConv<[

// Integer/FP values get stored in stack slots that are 8 bytes in size and
// 8-byte aligned if there are no more registers to hold them.
CCIfType<[i8, i16, i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>
CCIfType<[i8, i16, i32, i64, bf16, f16, f32, f64], CCAssignToStack<8, 8>>
]>;

def CC_X86_Win64_VectorCall : CallingConv<[
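For illustration (not part of the diff), here is a minimal C++ sketch of what the CCIfType rules above mean at the source level, assuming a clang with __bf16 support targeting SysV x86-64 with AVX-512. The function name is hypothetical.

// The first eight __bf16 arguments land in XMM0-XMM7; the ninth spills
// to an 8-byte, 8-byte-aligned stack slot per CCAssignToStack<8, 8>.
// The return value comes back in XMM0 per RetCC_X86_64_C.
__bf16 first_arg(__bf16 a, __bf16 b, __bf16 c, __bf16 d,
                 __bf16 e, __bf16 f, __bf16 g, __bf16 h,
                 __bf16 on_stack) {
  return a;
}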
5 changes: 4 additions & 1 deletion llvm/lib/Target/X86/X86FastISel.cpp
@@ -147,7 +147,8 @@ class X86FastISel final : public FastISel {
/// computed in an SSE register, not on the X87 floating point stack.
bool isScalarFPTypeInSSEReg(EVT VT) const {
return (VT == MVT::f64 && Subtarget->hasSSE2()) ||
(VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16;
(VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16 ||
VT == MVT::bf16;
}

bool isTypeLegal(Type *Ty, MVT &VT, bool AllowI1 = false);
@@ -2283,6 +2284,7 @@ bool X86FastISel::X86FastEmitPseudoSelect(MVT RetVT, const Instruction *I) {
case MVT::i16: Opc = X86::CMOV_GR16; break;
case MVT::i32: Opc = X86::CMOV_GR32; break;
case MVT::f16:
case MVT::bf16:
Opc = Subtarget->hasAVX512() ? X86::CMOV_FR16X : X86::CMOV_FR16; break;
case MVT::f32:
Opc = Subtarget->hasAVX512() ? X86::CMOV_FR32X : X86::CMOV_FR32; break;
@@ -3972,6 +3974,7 @@ Register X86FastISel::fastMaterializeFloatZero(const ConstantFP *CF) {
switch (VT.SimpleTy) {
default: return 0;
case MVT::f16:
case MVT::bf16:
Opc = HasAVX512 ? X86::AVX512_FsFLD0SH : X86::FsFLD0SH;
break;
case MVT::f32:
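For illustration (not part of the diff): with the FastISel changes above, an unoptimized build can select a scalar __bf16 conditional move and materialize a __bf16 zero without bailing out to SelectionDAG. A minimal sketch with hypothetical function names:

// pick() exercises X86FastEmitPseudoSelect: the ternary becomes a
// CMOV_FR16/CMOV_FR16X pseudo, the same path f16 already takes.
__bf16 pick(bool c, __bf16 a, __bf16 b) { return c ? a : b; }

// zero() exercises fastMaterializeFloatZero: the constant is built with
// FsFLD0SH/AVX512_FsFLD0SH, later expanded to a register-zeroing idiom.
__bf16 zero() { return (__bf16)0.0f; }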
63 changes: 62 additions & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
};

if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
// f16, f32 and f64 use SSE.
// f16, bf16, f32 and f64 use SSE.
// Set up the FP register classes.
addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
: &X86::FR16RegClass);
addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
: &X86::FR16RegClass);
addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
: &X86::FR32RegClass);
addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
@@ -676,6 +678,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// non-optsize case.
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);

// Set the operation action to Custom for bitcast and the bf16 conversions,
// and fall back to software libcalls for the conversions for now.
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
setOperationAction(ISD::FP_EXTEND, MVT::bf16, Custom);
setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);

for (auto VT : { MVT::f32, MVT::f64 }) {
// Use ANDPD to simulate FABS.
setOperationAction(ISD::FABS, VT, Custom);
@@ -22060,6 +22068,31 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
return Res;
}

if (SVT == MVT::bf16 && VT == MVT::f32) {
TargetLowering::CallLoweringInfo CLI(DAG);
Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();

TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry;
Entry.Node = In;
Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
Args.push_back(Entry);

SDValue Callee =
DAG.getExternalSymbol(getLibcallName(RTLIB::FPEXT_BF16_F32),
getPointerTy(DAG.getDataLayout()));
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
CallingConv::C, EVT(VT).getTypeForEVT(*DAG.getContext()), Callee,
std::move(Args));

SDValue Res;
std::tie(Res, Chain) = LowerCallTo(CLI);
if (IsStrict)
Res = DAG.getMergeValues({Res, Chain}, DL);

return Res;
}

if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
return Op;
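For illustration (not part of this hunk): on targets without native BF16 conversion support, the block added above lowers a scalar bf16-to-f32 extend to the FPEXT_BF16_F32 libcall rather than failing to select. A hypothetical sketch; the symbol name is the compiler-rt default for that libcall:

// Widening a __bf16 now emits a runtime call.
float widen(__bf16 x) {
  return (float)x;  // -> call to __extendbfsf2 (FPEXT_BF16_F32)
}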

@@ -22143,6 +22176,30 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
Subtarget.hasAVXNECONVERT()))
return Op;

// Need a soft libcall if the target does not have BF16.
if (SVT.getScalarType() == MVT::f32 || SVT.getScalarType() == MVT::f64) {
TargetLowering::CallLoweringInfo CLI(DAG);
Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();

TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry;
Entry.Node = In;
Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
Args.push_back(Entry);
SDValue Callee = DAG.getExternalSymbol(
getLibcallName(SVT == MVT::f64 ? RTLIB::FPROUND_F64_BF16
: RTLIB::FPROUND_F32_BF16),
getPointerTy(DAG.getDataLayout()));
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
CallingConv::C, EVT(MVT::bf16).getTypeForEVT(*DAG.getContext()),
Callee, std::move(Args));

SDValue Res;
std::tie(Res, Chain) = LowerCallTo(CLI);
return IsStrict ? DAG.getMergeValues({Res, Chain}, DL) : Res;
}

return SDValue();
}
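For illustration (not part of this hunk): when neither AVX512-BF16 with VLX nor AVXNECONVERT is available, scalar truncation to bf16 now takes the libcall path added above. A hypothetical sketch; the symbol names are the compiler-rt defaults for FPROUND_F32_BF16 and FPROUND_F64_BF16:

__bf16 narrow32(float f)  { return (__bf16)f; }  // -> __truncsfbf2
__bf16 narrow64(double d) { return (__bf16)d; }  // -> __truncdfbf2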

@@ -32151,6 +32208,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getZExtOrTrunc(V, DL, DstVT);
}

// Bitcasts between f16 and bf16 should be legal.
if (DstVT == MVT::f16 || DstVT == MVT::bf16)
return Op;

assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
SrcVT == MVT::i64) && "Unexpected VT!");

10 changes: 5 additions & 5 deletions llvm/lib/Target/X86/X86InstrAVX10.td
@@ -103,14 +103,14 @@ multiclass avx10_minmax_packed<string OpStr, AVX512VLVectorVTInfo VTI, SDNode Op
}

multiclass avx10_minmax_scalar<string OpStr, X86VectorVTInfo _, SDNode OpNode,
SDNode OpNodeSAE> {
SDNode OpNodeSAE, ValueType CT> {
let ExeDomain = _.ExeDomain, Predicates = [HasAVX10_2] in {
let mayRaiseFPException = 1 in {
let isCodeGenOnly = 1 in {
def rri : AVX512Ii8<0x53, MRMSrcReg, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.FRC:$src2, i32u8imm:$src3),
!strconcat(OpStr, "\t{$src3, $src2, $src1|$src1, $src2, $src3}"),
[(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2, (i32 timm:$src3)))]>,
[(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2), (i32 timm:$src3)))]>,
Sched<[WriteFMAX]>;

def rmi : AVX512Ii8<0x53, MRMSrcMem, (outs _.FRC:$dst),
@@ -165,11 +165,11 @@ defm VMINMAXPS : avx10_minmax_packed<"vminmaxps", avx512vl_f32_info, X86vminmax>
avx10_minmax_packed_sae<"vminmaxps", avx512vl_f32_info, X86vminmaxSae>,
AVX512PDIi8Base, TA, EVEX_CD8<32, CD8VF>;

defm VMINMAXSD : avx10_minmax_scalar<"vminmaxsd", v2f64x_info, X86vminmaxs, X86vminmaxsSae>,
defm VMINMAXSD : avx10_minmax_scalar<"vminmaxsd", v2f64x_info, X86vminmaxs, X86vminmaxsSae, f64>,
AVX512AIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<64, CD8VT1>, REX_W;
defm VMINMAXSH : avx10_minmax_scalar<"vminmaxsh", v8f16x_info, X86vminmaxs, X86vminmaxsSae>,
defm VMINMAXSH : avx10_minmax_scalar<"vminmaxsh", v8f16x_info, X86vminmaxs, X86vminmaxsSae, f16>,
AVX512PSIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<16, CD8VT1>, TA;
defm VMINMAXSS : avx10_minmax_scalar<"vminmaxss", v4f32x_info, X86vminmaxs, X86vminmaxsSae>,
defm VMINMAXSS : avx10_minmax_scalar<"vminmaxss", v4f32x_info, X86vminmaxs, X86vminmaxsSae, f32>,
AVX512AIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<32, CD8VT1>;

//-------------------------------------------------