Skip to content

Commit c26061e

Browse files
committed
AMDGPU: Shrink used number of registers for mfma scale based on format
Currently the builtins assume an 8-bit format, which requires an 8-element vector. We can shrink the number of registers used when the format only requires 4 or 6.
1 parent a4b14ad commit c26061e

File tree

3 files changed

+372
-2
lines changed

3 files changed

+372
-2
lines changed

clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,14 +433,16 @@ v16f test_mfma_f32_32x32x16_bf16(v8bf16 a, v8bf16 b, v16f c) {
433433
}
434434

435435
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_16x16x128_f8f6f4
436-
// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
436+
// CHECK-GFX950: [[EXTRACT_A:%.+]] = shufflevector <8 x i32> %a, <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
437+
// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> [[EXTRACT_A]], <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
437438
void test_mfma_scale_f32_16x16x128_f8f6f4(global v4f* out, v8i a, v8i b, v4f c, int scale_a, int scale_b)
438439
{
439440
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
440441
}
441442

442443
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_32x32x64_f8f6f4
443-
// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
444+
// CHECK-GFX950: [[EXTRACT_A:%.+]] = shufflevector <8 x i32> %a, <8 x i32> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
445+
// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> [[EXTRACT_A]], <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
444446
void test_mfma_scale_f32_32x32x64_f8f6f4(global v16f* out, v8i a, v8i b, v16f c, int scale_a, int scale_b)
445447
{
446448
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,62 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
12601260
}
12611261
return std::nullopt;
12621262
}
1263+
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
  // Shrink the vector source operands of the scaled MFMA intrinsics to the
  // number of registers the selected matrix format actually needs. Operands 3
  // and 4 (cbsz / blgp) are constant format selectors for src0 and src1
  // respectively.
  Value *Src0 = II.getArgOperand(0);
  Value *Src1 = II.getArgOperand(1);
  uint64_t CBSZ = cast<ConstantInt>(II.getArgOperand(3))->getZExtValue();
  uint64_t BLGP = cast<ConstantInt>(II.getArgOperand(4))->getZExtValue();
  auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
  auto *Src1Ty = cast<FixedVectorType>(Src1->getType());

  // Map a format selector to the number of vector elements (registers) that
  // format occupies: 8 for the FP8 formats, 6 for FP6, 4 for FP4.
  auto getFormatNumRegs = [](unsigned FormatVal) {
    switch (FormatVal) {
    case AMDGPU::MFMAScaleFormats::FP6_E2M3:
    case AMDGPU::MFMAScaleFormats::FP6_E3M2:
      return 6u;
    case AMDGPU::MFMAScaleFormats::FP4_E2M1:
      return 4u;
    case AMDGPU::MFMAScaleFormats::FP8_E4M3:
    case AMDGPU::MFMAScaleFormats::FP8_E5M2:
      return 8u;
    default:
      llvm_unreachable("invalid format value");
    }
  };

  bool MadeChange = false;
  unsigned Src0NumElts = getFormatNumRegs(CBSZ);
  unsigned Src1NumElts = getFormatNumRegs(BLGP);

  // Depending on the used format, fewer registers are required so shrink the
  // vector type. Only the leading elements (starting at index 0) are kept.
  if (Src0Ty->getNumElements() > Src0NumElts) {
    Src0 = IC.Builder.CreateExtractVector(
        FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
        IC.Builder.getInt64(0));
    MadeChange = true;
  }

  if (Src1Ty->getNumElements() > Src1NumElts) {
    // Derive the narrowed type from Src1's own element type rather than
    // Src0's, so the two operands are handled independently.
    Src1 = IC.Builder.CreateExtractVector(
        FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
        IC.Builder.getInt64(0));
    MadeChange = true;
  }

  // Both operands already have the minimal width: nothing to rewrite.
  if (!MadeChange)
    return std::nullopt;

  SmallVector<Value *, 10> Args(II.args());
  Args[0] = Src0;
  Args[1] = Src1;

  // Re-create the call with the intrinsic overloaded on the (possibly
  // narrower) source vector types, then replace the original call with it.
  CallInst *NewII = IC.Builder.CreateIntrinsic(
      IID, {Src0->getType(), Src1->getType()}, Args, &II);
  NewII->takeName(&II);
  return IC.replaceInstUsesWith(II, NewII);
}
12631319
}
12641320
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
12651321
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {

0 commit comments

Comments
 (0)