Skip to content

Commit 82bb6e0

Browse files
committed
AMDGPU: Define v_mfma_f32_32x32x16_bf16 for gfx950
Unlike the existing gfx940 intrinsics using short/i16 in place of bfloat, this uses the natural bfloat type.
1 parent f3682aa commit 82bb6e0

File tree

12 files changed

+596
-8
lines changed

12 files changed

+596
-8
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,8 @@ TARGET_BUILTIN(__builtin_amdgcn_cvt_sr_fp8_f32, "ifiiIi", "nc", "fp8-conversion-
437437
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_16x16x32_f16, "V4fV8hV8hV4fIiIiIi", "nc", "gfx950-insts")
438438
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_f16, "V16fV8hV8hV16fIiIiIi", "nc", "gfx950-insts")
439439

440+
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_bf16, "V16fV8yV8yV16fIiIiIi", "nc", "gfx950-insts")
441+
440442
//===----------------------------------------------------------------------===//
441443
// GFX12+ only builtins.
442444
//===----------------------------------------------------------------------===//

clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ typedef short v8s __attribute__((ext_vector_type(8)));
2424
typedef short v16s __attribute__((ext_vector_type(16)));
2525
typedef short v32s __attribute__((ext_vector_type(32)));
2626
typedef double v4d __attribute__((ext_vector_type(4)));
27+
typedef __bf16 v8bf16 __attribute__((ext_vector_type(8)));
2728

2829

2930
#ifdef MFMA_GFX908_TESTS
@@ -424,5 +425,10 @@ v16f test_mfma_f32_32x32x16_f16(v8h a, v8h b, v16f c)
424425
return __builtin_amdgcn_mfma_f32_32x32x16_f16(a, b, c, 1, 2, 3);
425426
}
426427

428+
// CHECK-GFX950-LABEL: @test_mfma_f32_32x32x16_bf16(
429+
// CHECK-GFX950: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %a, <8 x bfloat> %b, <16 x float> %c, i32 1, i32 2, i32 3)
430+
v16f test_mfma_f32_32x32x16_bf16(v8bf16 a, v8bf16 b, v16f c) {
431+
return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 1, 2, 3);
432+
}
427433

428434
#endif

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950-param.cl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
7+
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
78

89

910
void test_mfma_f32_16x16x32_f16(__global float4* out, half8 a, half8 b, float4 c, int X) {
@@ -19,3 +20,9 @@ void test_mfma_f32_32x32x16_f16(__global float16* out, half8 a, half8 b, float16
1920
*out = __builtin_amdgcn_mfma_f32_32x32x16_f16(a, b, c, 0, X, 0); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_f16' must be a constant integer}}
2021
*out = __builtin_amdgcn_mfma_f32_32x32x16_f16(a, b, c, 0, 0, X); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_f16' must be a constant integer}}
2122
}
23+
24+
void test_mfma_f32_32x32x16_bf16(__global float16* out, bfloat8 a, bfloat8 b, float16 c, int X) {
25+
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, X, 0, 0); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
26+
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, X, 0); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
27+
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, 0, X); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
28+
}

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950.cl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
7+
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
78

89
void test(__global float4* out0, half8 a0, half8 b0, float4 c0,
9-
__global float16* out1, half8 a1, half8 b1, float16 c1) {
10+
__global float16* out1, half8 a1, half8 b1, float16 c1,
11+
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2) {
1012
*out0 = __builtin_amdgcn_mfma_f32_16x16x32_f16(a0, b0, c0, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_16x16x32_f16' needs target feature gfx950-insts}}
1113
*out1 = __builtin_amdgcn_mfma_f32_32x32x16_f16(a1, b1, c1, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_f16' needs target feature gfx950-insts}}
14+
*out2 = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a2, b2, c2, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_bf16' needs target feature gfx950-insts}}
1215
}

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3117,6 +3117,8 @@ def int_amdgcn_cvt_sr_fp8_f32 : ClangBuiltin<"__builtin_amdgcn_cvt_sr_fp8_f32">,
31173117
defset list<Intrinsic> AMDGPUMFMAIntrinsics950 = {
31183118
def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v8f16_ty>;
31193119
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;
3120+
3121+
def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
31203122
}
31213123

31223124
//===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2848,6 +2848,7 @@ def VOP_V16F32_V2I32_V4I32_I32 : VOPProfile <[v16f32, v2i32, v4i32, i32]>;
28482848

28492849
def VOP_V4F32_V8F16_V8F16_V4F32 : VOPProfile <[v4f32, v8f16, v8f16, v4f32]>;
28502850
def VOP_V16F32_V8F16_V8F16_V16F32 : VOPProfile <[v16f32, v8f16, v8f16, v16f32]>;
2851+
def VOP_V16F32_V8BF16_V8BF16_V16F32 : VOPProfile <[v16f32, v8bf16, v8bf16, v16f32]>;
28512852

28522853

28532854
class Commutable_REV <string revOp, bit isOrig> {

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,9 @@ def VOPProfileMAI_F32_V8F16_X32_VCD : VOPProfileMAI<VOP_V4F32_V8F16_V8F16_V4F32,
631631
def VOPProfileMAI_F32_V8F16_X16 : VOPProfileMAI<VOP_V16F32_V8F16_V8F16_V16F32, AISrc_512_f32, ADst_512, AVSrc_128>;
632632
def VOPProfileMAI_F32_V8F16_X16_VCD : VOPProfileMAI<VOP_V16F32_V8F16_V8F16_V16F32, VISrc_512_f32, VDst_512, AVSrc_128>;
633633

634+
def VOPProfileMAI_F32_V8BF16_X16 : VOPProfileMAI<VOP_V16F32_V8BF16_V8BF16_V16F32, AISrc_512_f32, ADst_512, AVSrc_128>;
635+
def VOPProfileMAI_F32_V8BF16_X16_VCD : VOPProfileMAI<VOP_V16F32_V8BF16_V8BF16_V16F32, VISrc_512_f32, VDst_512, AVSrc_128>;
636+
634637
class MFMATable <bit is_mac, string Name> {
635638
bit IsMac = is_mac;
636639
string FMAOp = Name;
@@ -747,6 +750,7 @@ defm V_MFMA_F32_32X32X4BF16 : MAIInst<"v_mfma_f32_32x32x4bf16", "F32_V2I16_X16",
747750
let SubtargetPredicate = HasGFX950Insts, is_gfx940_xdl = 1 in {
748751
defm V_MFMA_F32_16X16X32_F16 : MAIInst<"v_mfma_f32_16x16x32f16", "F32_V8F16_X32", int_amdgcn_mfma_f32_16x16x32_f16>;
749752
defm V_MFMA_F32_32X32X16_F16 : MAIInst<"v_mfma_f32_32x32x16f16", "F32_V8F16_X16", int_amdgcn_mfma_f32_32x32x16_f16>;
753+
defm V_MFMA_F32_32X32X16_BF16 : MAIInst<"v_mfma_f32_32x32x16bf16", "F32_V8BF16_X16", int_amdgcn_mfma_f32_32x32x16_bf16>;
750754
}
751755

752756
let Predicates = [isGFX90APlus] in {
@@ -1786,6 +1790,8 @@ defm V_MFMA_F64_4X4X4F64 : VOP3P_Real_MFMA_gfx90a <0x6f>;
17861790

17871791
defm V_MFMA_F32_16X16X32_F16 : VOP3P_Real_MFMA_gfx950 <0x54, "v_mfma_f32_16x16x32_f16">;
17881792
defm V_MFMA_F32_32X32X16_F16 : VOP3P_Real_MFMA_gfx950 <0x55, "v_mfma_f32_32x32x16_f16">;
1793+
defm V_MFMA_F32_32X32X16_BF16 : VOP3P_Real_MFMA_gfx950 <0x37, "v_mfma_f32_32x32x16_bf16">;
1794+
17891795
defm V_MFMA_I32_32X32X16I8 : VOP3P_Real_MFMA_gfx940 <0x56, "v_mfma_i32_32x32x16_i8">;
17901796
defm V_MFMA_I32_16X16X32I8 : VOP3P_Real_MFMA_gfx940 <0x57, "v_mfma_i32_16x16x32_i8">;
17911797
let SubtargetPredicate = HasXF32Insts in {

llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,14 @@ define amdgpu_kernel void @mfma_f32_32x32x16_f16(<8 x half> %arg0, <8 x half> %a
278278
ret void
279279
}
280280

281+
; CHECK: DIVERGENT: %result = call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %arg0, <8 x bfloat> %arg1, <16 x float> %arg2, i32 immarg 0, i32 immarg 0, i32 immarg 0)
282+
define amdgpu_kernel void @mfma_f32_32x32x16_bf16(<8 x bfloat> %arg0, <8 x bfloat> %arg1, <16 x float> %arg2, ptr addrspace(1) %out) {
283+
%result = call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %arg0, <8 x bfloat> %arg1, <16 x float> %arg2, i32 immarg 0, i32 immarg 0, i32 immarg 0)
284+
store <16 x float> %result, ptr addrspace(1) %out
285+
ret void
286+
}
287+
288+
281289
declare i32 @llvm.amdgcn.ds.swizzle(i32, i32) #1
282290
declare i32 @llvm.amdgcn.permlane16.i32(i32, i32, i32, i32, i1, i1) #1
283291
declare i32 @llvm.amdgcn.permlanex16.i32(i32, i32, i32, i32, i1, i1) #1

0 commit comments

Comments
 (0)