Skip to content

[X86] Improve __bf16 code generation #134859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
};

if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
// f16, f32 and f64 use SSE.
// f16, bf16, f32 and f64 use SSE.
// Set up the FP register classes.
addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
: &X86::FR16RegClass);
addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
: &X86::FR16RegClass);
addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
: &X86::FR32RegClass);
addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
Expand All @@ -676,6 +678,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// non-optsize case.
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);

// Set the operation action Custom for bitcast to do the customization
// later.
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);

for (auto VT : { MVT::f32, MVT::f64 }) {
// Use ANDPD to simulate FABS.
setOperationAction(ISD::FABS, VT, Custom);
Expand Down Expand Up @@ -32151,6 +32157,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getZExtOrTrunc(V, DL, DstVT);
}

// Bitcasts between f16 and bf16 should be legal.
if (DstVT == MVT::f16 || DstVT == MVT::bf16)
return Op;

assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
SrcVT == MVT::i64) && "Unexpected VT!");

Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,11 @@ let Predicates = [HasFP16] in {
(VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
}

let Predicates = [HasAVX512, HasBF16] in {
def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phoebewang Could we reuse FR16X for bf16 as well here? Changing FR16X definition to include bf16:


def FR16X : RegisterClass<"X86", [f16, bf16], 16, (add FR32X)> {let Size = 32;}

Seems to lead to a bunch of compilation issues due to ambiguity in instruction pattern :( Should we create a new class?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can use FR16X for bf16. We used in this way for different vector types within the same size. It just requires updating affected patterns with explicit type cast.


// ----------------------------------------------------------------
// FPClass

Expand Down
Loading