Skip to content

Commit 2a450e7

Browse files
arsenm authored and pravinjagtap committed
AMDGPU: Define v_mfma_f32_{16x16x128|32x32x64}_f8f6f4 instructions (llvm#116723)
These use a new VOP3PX encoding for the v_mfma_scale_* instructions, which bundles the pre-scale v_mfma_ld_scale_b32. None of the modifiers are supported yet (op_sel, neg or clamp). I'm not sure the intrinsic should really expose op_sel (or any of the others). If I'm reading the documentation correctly, we should be able to just have the raw scale operands and auto-match op_sel to byte extract patterns. The op_sel syntax also seems extra horrible in this usage, especially with the usual assumed op_sel_hi=-1 behavior.
1 parent 2f4d0ac commit 2a450e7

33 files changed

+9930
-33
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,11 @@ TARGET_BUILTIN(__builtin_amdgcn_cvt_sr_fp8_f32, "ifiiIi", "nc", "fp8-conversion-
434434
//===----------------------------------------------------------------------===//
435435
// GFX950 only builtins.
436436
//===----------------------------------------------------------------------===//
437+
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4, "V4fV8ZiV8ZiV4fIiIiIiiIii", "nc", "gfx950-insts")
438+
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4, "V16fV8ZiV8ZiV16fIiIiIiiIii", "nc", "gfx950-insts")
439+
437440
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_16x16x32_f16, "V4fV8hV8hV4fIiIiIi", "nc", "gfx950-insts")
438441
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_f16, "V16fV8hV8hV16fIiIiIi", "nc", "gfx950-insts")
439-
440442
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_bf16, "V16fV8yV8yV16fIiIiIi", "nc", "gfx950-insts")
441443

442444
//===----------------------------------------------------------------------===//

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18909,7 +18909,20 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1890918909
(uint64_t)0);
1891018910
return Builder.CreateInsertElement(I0, A, 1);
1891118911
}
18912-
18912+
case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
18913+
case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
18914+
llvm::FixedVectorType *VT = FixedVectorType::get(Builder.getInt32Ty(), 8);
18915+
Function *F = CGM.getIntrinsic(
18916+
BuiltinID == AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4
18917+
? Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4
18918+
: Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4,
18919+
{VT, VT});
18920+
18921+
SmallVector<Value *, 9> Args;
18922+
for (unsigned I = 0, N = E->getNumArgs(); I != N; ++I)
18923+
Args.push_back(EmitScalarExpr(E->getArg(I)));
18924+
return Builder.CreateCall(F, Args);
18925+
}
1891318926
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
1891418927
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
1891518928
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:

clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ typedef half v16h __attribute__((ext_vector_type(16)));
1616
typedef half v32h __attribute__((ext_vector_type(32)));
1717
typedef int v2i __attribute__((ext_vector_type(2)));
1818
typedef int v4i __attribute__((ext_vector_type(4)));
19+
typedef int v8i __attribute__((ext_vector_type(8)));
1920
typedef int v16i __attribute__((ext_vector_type(16)));
2021
typedef int v32i __attribute__((ext_vector_type(32)));
2122
typedef short v2s __attribute__((ext_vector_type(2)));
@@ -431,4 +432,18 @@ v16f test_mfma_f32_32x32x16_bf16(v8bf16 a, v8bf16 b, v16f c) {
431432
return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 1, 2, 3);
432433
}
433434

435+
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_16x16x128_f8f6f4
436+
// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
437+
void test_mfma_scale_f32_16x16x128_f8f6f4(global v4f* out, v8i a, v8i b, v4f c, int scale_a, int scale_b)
438+
{
439+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
440+
}
441+
442+
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_32x32x64_f8f6f4
443+
// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
444+
void test_mfma_scale_f32_32x32x64_f8f6f4(global v16f* out, v8i a, v8i b, v16f c, int scale_a, int scale_b)
445+
{
446+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
447+
}
448+
434449
#endif

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950-param.cl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
77
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
8+
typedef int int8 __attribute__((ext_vector_type(8)));
89

910

1011
void test_mfma_f32_16x16x32_f16(__global float4* out, half8 a, half8 b, float4 c, int X) {
@@ -26,3 +27,17 @@ void test_mfma_f32_32x32x16_bf16(__global float16* out, bfloat8 a, bfloat8 b, fl
2627
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, X, 0); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
2728
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, 0, X); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
2829
}
30+
31+
void test_mfma_scale_f32_16x16x128_f8f6f4(__global float4* out, int8 a, int8 b, float4 c, int X, int Y) {
32+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
33+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
34+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
35+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
36+
}
37+
38+
void test_mfma_scale_f32_32x32x64_f8f6f4(__global float16* out, int8 a, int8 b, float16 c, int X, int Y) {
39+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
40+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
41+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
42+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
43+
}

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950.cl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,33 @@
44
typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
7+
typedef half half16 __attribute__((ext_vector_type(16)));
78
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
9+
typedef __bf16 bfloat16 __attribute__((ext_vector_type(16)));
10+
typedef unsigned int uint2 __attribute__((ext_vector_type(2)));
11+
typedef int int4 __attribute__((ext_vector_type(4)));
12+
typedef int int8 __attribute__((ext_vector_type(8)));
13+
typedef int int16 __attribute__((ext_vector_type(16)));
814

915
void test(__global float4* out0, half8 a0, half8 b0, float4 c0,
1016
__global float16* out1, half8 a1, half8 b1, float16 c1,
11-
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2) {
17+
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2,
18+
__global int4* out3, int4 a3, int4 b3, int4 c3,
19+
__global int16* out4, int4 a4, int4 b4, int16 c4,
20+
__global float4* out5, bfloat8 a5, bfloat8 b5, float4 c5,
21+
__global float4* out6, half8 a6, half16 b6, float4 c6,
22+
__global float16* out7, half8 a7, half16 b7, float16 c7,
23+
__global float4* out8, bfloat8 a8, bfloat16 b8, float4 c8,
24+
__global float16* out9, bfloat8 a9, bfloat16 b9, float16 c9,
25+
__global int4* out10, int4 a10, int8 b10, int4 c10,
26+
__global int16* out11, int4 a11, int8 b11, int16 c11,
27+
__global float4* out12, int4 a12, int8 b12, float4 c12,
28+
__global float16* out13, int4 a13, int8 b13, float16 c13,
29+
__global float4* out14, int8 a14, int8 b14, float4 c14, int d14, int e14,
30+
__global float16* out15, int8 a15, int8 b15, float16 c15, int d15, int e15) {
1231
*out0 = __builtin_amdgcn_mfma_f32_16x16x32_f16(a0, b0, c0, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_16x16x32_f16' needs target feature gfx950-insts}}
1332
*out1 = __builtin_amdgcn_mfma_f32_32x32x16_f16(a1, b1, c1, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_f16' needs target feature gfx950-insts}}
1433
*out2 = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a2, b2, c2, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_bf16' needs target feature gfx950-insts}}
34+
*out14 = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a14, b14, c14, 0, 0, 0, d14, 0, e14); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' needs target feature gfx950-insts}}
35+
*out15 = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a15, b15, c15, 0, 0, 0, d15, 0, e15); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' needs target feature gfx950-insts}}
1536
}

llvm/docs/AMDGPUUsage.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,16 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
13911391
sign-extended from the width of the underlying PC hardware register even on
13921392
processors where the s_getpc_b64 instruction returns a zero-extended value.
13931393

1394+
llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4 Emit `v_mfma_scale_f32_16x16x128_f8f6f4` to set the scale factor. The
1395+
last 4 operands correspond to the scale inputs.
1396+
1397+
- 2-bit byte index to use for each lane for matrix A
1398+
- Matrix A scale values
1399+
- 2-bit byte index to use for each lane for matrix B
1400+
- Matrix B scale values
1401+
1402+
llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4 Emit `v_mfma_scale_f32_32x32x64_f8f6f4`
1403+
13941404
============================================== ==========================================================
13951405

13961406
.. TODO::

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2987,6 +2987,35 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
29872987
[IntrConvergent, IntrNoMem,
29882988
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
29892989

2990+
2991+
// srcA's format is determined by cbsz. srcB's format is determined by
2992+
// blgp.
2993+
//
2994+
// These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
2995+
// and <4 x i32> for f4 formats. If the format control bits imply a
2996+
// smaller type than used, the high elements will be truncated.
2997+
//
2998+
// If the format control bits imply a larger type than used, the high
2999+
// elements are padded with undef.
3000+
3001+
class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
3002+
DefaultAttrsIntrinsic<[DestTy],
3003+
[llvm_anyvector_ty, llvm_anyvector_ty, DestTy,
3004+
llvm_i32_ty, // cbsz
3005+
llvm_i32_ty, // blgp
3006+
// llvm_i1_ty, // TODO: neg_src2
3007+
// llvm_i1_ty, // TODO: abs_src2
3008+
// llvm_i1_ty, // TODO: clamp
3009+
llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
3010+
llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
3011+
llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
3012+
llvm_i32_ty // v_mfma_ld_scale_b32 src1 (B matrix scale)
3013+
],
3014+
[IntrConvergent, IntrNoMem,
3015+
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>,
3016+
ImmArg<ArgIndex<5>>, ImmArg<ArgIndex<7>>
3017+
]>;
3018+
29903019
defset list<Intrinsic> AMDGPUMFMAIntrinsics908 = {
29913020
def int_amdgcn_mfma_f32_32x32x1f32 : AMDGPUMfmaIntrinsic<llvm_v32f32_ty, llvm_float_ty>;
29923021
def int_amdgcn_mfma_f32_16x16x1f32 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_float_ty>;
@@ -3148,6 +3177,8 @@ def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v
31483177
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;
31493178

31503179
def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
3180+
def int_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v4f32_ty>;
3181+
def int_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v16f32_ty>;
31513182
}
31523183

31533184
//===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/AMDGPUGISel.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,6 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">,
418418

419419
def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">,
420420
GISDNodeXFormEquiv<as_hw_round_mode>;
421+
422+
def gi_MFMALdScaleModifierOp : GICustomOperandRenderer<"renderScaledMAIIntrinsicOperand">,
423+
GISDNodeXFormEquiv<MFMALdScaleXForm>;

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,7 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
12811281
if (isa<UndefValue>(Src)) {
12821282
return IC.replaceInstUsesWith(II, Src);
12831283
}
1284+
return std::nullopt;
12841285
}
12851286
}
12861287
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5711,6 +5711,18 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB,
57115711
MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4);
57125712
}
57135713

5714+
/// Convert from 2-bit value to enum values used for op_sel* source modifiers.
///
/// Reads the immediate stored in operand \p OpIdx of \p MI, translates its two
/// low bits into a SISrcMods mask (bit 0 -> OP_SEL_0, bit 1 -> OP_SEL_1), and
/// appends that mask to \p MIB as an immediate operand. Bits above the low two
/// are ignored.
5715+
void AMDGPUInstructionSelector::renderScaledMAIIntrinsicOperand(
5716+
MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const {
5717+
unsigned Val = MI.getOperand(OpIdx).getImm();
5718+
// Accumulates the source-modifier bits; stays 0 when neither low bit is set.
unsigned New = 0;
5719+
// Low selector bit maps to OP_SEL_0.
if (Val & 0x1)
5720+
New |= SISrcMods::OP_SEL_0;
5721+
// High selector bit maps to OP_SEL_1.
if (Val & 0x2)
5722+
New |= SISrcMods::OP_SEL_1;
5723+
MIB.addImm(New);
5724+
}
5725+
57145726
bool AMDGPUInstructionSelector::isInlineImmediate(const APInt &Imm) const {
57155727
return TII.isInlineConstant(Imm);
57165728
}

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
353353

354354
void renderRoundMode(MachineInstrBuilder &MIB, const MachineInstr &MI,
355355
int OpIdx) const;
356+
void renderScaledMAIIntrinsicOperand(MachineInstrBuilder &MIB,
357+
const MachineInstr &MI, int OpIdx) const;
356358

357359
bool isInlineImmediate(const APInt &Imm) const;
358360
bool isInlineImmediate(const APFloat &Imm) const;

llvm/lib/Target/AMDGPU/AMDGPUInstructions.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class AMDGPUInst <dag outs, dag ins, string asm = "",
4040
// instructions to not match without killing the whole decode process. It is
4141
// mainly used for ARM, but Tablegen expects this field to exist or it fails
4242
// to build the decode table.
43-
field bits<96> SoftFail = 0;
43+
field bits<128> SoftFail = 0; // FIXME: If this is smaller than largest instruction, DecodeEmitter crashes
4444

4545
let DecoderNamespace = Namespace;
4646

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4746,6 +4746,25 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
47464746
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
47474747
break;
47484748
}
4749+
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
4750+
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
4751+
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
4752+
OpdsMapping[0] =
4753+
Info->mayNeedAGPRs()
4754+
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
4755+
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
4756+
4757+
OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
4758+
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
4759+
OpdsMapping[4] =
4760+
Info->mayNeedAGPRs()
4761+
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
4762+
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
4763+
4764+
OpdsMapping[8] = getVGPROpMapping(MI.getOperand(8).getReg(), MRI, *TRI);
4765+
OpdsMapping[10] = getVGPROpMapping(MI.getOperand(10).getReg(), MRI, *TRI);
4766+
break;
4767+
}
47494768
case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
47504769
case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
47514770
case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:

llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,17 @@ static inline DecoderUInt128 eat12Bytes(ArrayRef<uint8_t> &Bytes) {
477477
return DecoderUInt128(Lo, Hi);
478478
}
479479

480+
/// Consume 16 bytes from the front of \p Bytes and return them as a
/// DecoderUInt128.
///
/// The data is read as two consecutive little-endian 64-bit words; the first
/// word is passed as the low half and the second as the high half. \p Bytes is
/// updated in place to skip the 16 bytes that were read. The caller must
/// guarantee at least 16 bytes are available (checked by assert only).
static inline DecoderUInt128 eat16Bytes(ArrayRef<uint8_t> &Bytes) {
481+
assert(Bytes.size() >= 16);
482+
// First 8 bytes: low 64 bits.
uint64_t Lo =
483+
support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
484+
Bytes = Bytes.slice(8);
485+
// Next 8 bytes: high 64 bits.
uint64_t Hi =
486+
support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
487+
Bytes = Bytes.slice(8);
488+
return DecoderUInt128(Lo, Hi);
489+
}
490+
480491
DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
481492
ArrayRef<uint8_t> Bytes_,
482493
uint64_t Address,
@@ -513,6 +524,15 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
513524

514525
// Reinitialize Bytes
515526
Bytes = Bytes_.slice(0, MaxInstBytesNum);
527+
528+
} else if (Bytes.size() >= 16 &&
529+
STI.hasFeature(AMDGPU::FeatureGFX950Insts)) {
530+
DecoderUInt128 DecW = eat16Bytes(Bytes);
531+
if (tryDecodeInst(DecoderTableGFX940128, MI, DecW, Address, CS))
532+
break;
533+
534+
// Reinitialize Bytes
535+
Bytes = Bytes_.slice(0, MaxInstBytesNum);
516536
}
517537

518538
if (Bytes.size() >= 8) {
@@ -722,6 +742,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
722742
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::SDWA)
723743
convertSDWAInst(MI);
724744

745+
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
746+
convertMAIInst(MI);
747+
725748
int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
726749
AMDGPU::OpName::vdst_in);
727750
if (VDstIn_Idx != -1) {
@@ -791,6 +814,58 @@ void AMDGPUDisassembler::convertSDWAInst(MCInst &MI) const {
791814
}
792815
}
793816

817+
/// Adjust the register values used by V_MFMA_F8F6F4_f8_f8 instructions to the
818+
/// appropriate subregister for the used format width.
///
/// \param MRI     Register info used to look up the subregister.
/// \param MO      Register operand to rewrite in place.
/// \param NumRegs Number of 32-bit registers the operand should occupy; only
///                4, 6 and 8 are accepted, anything else is unreachable.
819+
static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
820+
MCOperand &MO, uint8_t NumRegs) {
821+
switch (NumRegs) {
822+
case 4:
823+
// Narrow to the first four sub-registers.
return MO.setReg(MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3));
824+
case 6:
825+
// Narrow to the first six sub-registers.
return MO.setReg(
826+
MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
827+
case 8:
828+
// No-op in cases where one operand is still f8/bf8.
829+
return;
830+
default:
831+
llvm_unreachable("Unexpected size for mfma f8f6f4 operand");
832+
}
833+
}
834+
835+
/// f8f6f4 instructions have different pseudos depending on the used formats. In
836+
/// the disassembler table, we only have the variants with the largest register
837+
/// classes which assume using an fp8/bf8 format for both operands. The actual
838+
/// register class depends on the format in blgp and cbsz operands. Adjust the
839+
/// register classes depending on the used format.
///
/// Instructions without a blgp operand are left untouched, as are instructions
/// for which no adjusted opcode is found or the adjusted opcode equals the
/// current one. Otherwise the opcode is replaced and src0/src1 are narrowed to
/// the register counts recorded for the adjusted opcode.
840+
void AMDGPUDisassembler::convertMAIInst(MCInst &MI) const {
841+
int BlgpIdx =
842+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::blgp);
843+
// No blgp operand: nothing to adjust for this instruction.
if (BlgpIdx == -1)
844+
return;
845+
846+
// NOTE(review): CbszIdx is used below without a -1 check; presumably every
// opcode that has blgp also has cbsz -- confirm.
int CbszIdx =
847+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::cbsz);
848+
849+
unsigned CBSZ = MI.getOperand(CbszIdx).getImm();
850+
unsigned BLGP = MI.getOperand(BlgpIdx).getImm();
851+
852+
// Look up the opcode variant matching the decoded format controls.
const AMDGPU::MFMA_F8F6F4_Info *AdjustedRegClassOpcode =
853+
AMDGPU::getMFMA_F8F6F4_WithFormatArgs(CBSZ, BLGP, MI.getOpcode());
854+
if (!AdjustedRegClassOpcode ||
855+
AdjustedRegClassOpcode->Opcode == MI.getOpcode())
856+
return;
857+
858+
MI.setOpcode(AdjustedRegClassOpcode->Opcode);
859+
// The source operand indices are queried after the opcode swap, i.e. against
// the adjusted opcode.
int Src0Idx =
860+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src0);
861+
int Src1Idx =
862+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src1);
863+
adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src0Idx),
864+
AdjustedRegClassOpcode->NumRegsSrcA);
865+
adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src1Idx),
866+
AdjustedRegClassOpcode->NumRegsSrcB);
867+
}
868+
794869
struct VOPModifiers {
795870
unsigned OpSel = 0;
796871
unsigned OpSelHi = 0;

llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class AMDGPUDisassembler : public MCDisassembler {
204204
void convertVINTERPInst(MCInst &MI) const;
205205
void convertFMAanyK(MCInst &MI, int ImmLitIdx) const;
206206
void convertSDWAInst(MCInst &MI) const;
207+
void convertMAIInst(MCInst &MI) const;
207208
void convertDPP8Inst(MCInst &MI) const;
208209
void convertMIMGInst(MCInst &MI) const;
209210
void convertVOP3DPPInst(MCInst &MI) const;

0 commit comments

Comments
 (0)