Skip to content

Commit a05a1d6

Browse files
authored
AMDGPU: Add basic verification for mfma scale intrinsics (#117048)
Verify the format is valid and the type is one of the expected i32 vectors. Verify the used vector types at least cover the requirements of the corresponding format operand.
1 parent 7d544c6 commit a05a1d6

File tree

3 files changed

+283
-6
lines changed

3 files changed

+283
-6
lines changed

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,12 +2973,10 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
29732973
// blgp.
29742974
//
29752975
// These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
2976-
// and <4 x i32> for f4 formats. If the format control bits imply a
2977-
// smaller type than used, the high elements will be truncated.
2978-
//
2979-
// If the format control bits imply a larger type than used, the high
2980-
// elements are padded with undef.
2981-
2976+
// and <4 x i32> for f4 formats. It is invalid to use a format that
2977+
// requires more registers than the corresponding vector type (e.g. it
2978+
// is illegal to use <6 x i32> in operand 0 if cbsz specifies an f8
2979+
// format that requires 8 registers).
29822980
class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
29832981
DefaultAttrsIntrinsic<[DestTy],
29842982
[llvm_anyvector_ty, llvm_anyvector_ty, DestTy,

llvm/lib/IR/Verifier.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6390,6 +6390,55 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
63906390
"llvm.amdgcn.s.prefetch.data only supports global or constant memory");
63916391
break;
63926392
}
6393+
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
  // Basic verification of the scaled MFMA intrinsics:
  //   * cbsz (operand 3) and blgp (operand 4) must be valid format values,
  //   * operands 0 and 1 must be <4 x i32>, <6 x i32> or <8 x i32>,
  //   * each source vector must supply at least as many registers as the
  //     format selected for it requires (cbsz -> operand 0,
  //     blgp -> operand 1).
  Value *Src0 = Call.getArgOperand(0);
  Value *Src1 = Call.getArgOperand(1);
  Value *CBSZOp = Call.getArgOperand(3);
  Value *BLGPOp = Call.getArgOperand(4);

  uint64_t CBSZ = cast<ConstantInt>(CBSZOp)->getZExtValue();
  uint64_t BLGP = cast<ConstantInt>(BLGPOp)->getZExtValue();
  Check(CBSZ <= 4, "invalid value for cbsz format", &Call, CBSZOp);
  Check(BLGP <= 4, "invalid value for blgp format", &Call, BLGPOp);

  // Number of 32-bit registers needed by each format value.
  // AMDGPU::MFMAScaleFormats values
  // Only reached with values <= 4; both format operands were range-checked
  // above.
  auto getFormatNumRegs = [](unsigned FormatVal) {
    switch (FormatVal) {
    case 0:
    case 1:
      return 8u;
    case 2:
    case 3:
      return 6u;
    case 4:
      return 4u;
    default:
      llvm_unreachable("invalid format value");
    }
  };

  // True for a fixed vector of i32 with 4, 6 or 8 elements. A null type
  // (i.e. the operand was not a fixed vector at all) is rejected.
  auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
    if (!Ty || !Ty->getElementType()->isIntegerTy(32))
      return false;
    unsigned NumElts = Ty->getNumElements();
    return NumElts == 4 || NumElts == 6 || NumElts == 8;
  };

  auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
  auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
  // NOTE(review): the dereferences of Src0Ty/Src1Ty below rely on Check()
  // leaving this function when the condition fails - confirm against the
  // Check macro definition.
  Check(isValidSrcASrcBVector(Src0Ty),
        "operand 0 must be 4, 6 or 8 element i32 vector", &Call, Src0);
  Check(isValidSrcASrcBVector(Src1Ty),
        "operand 1 must be 4, 6 or 8 element i32 vector", &Call, Src1);

  // Permit excess registers for the format.
  Check(Src0Ty->getNumElements() >= getFormatNumRegs(CBSZ),
        "invalid vector type for format", &Call, Src0, CBSZOp);
  // Report the blgp operand (operand 4) alongside src1; it is the format
  // control that this check is validating against.
  Check(Src1Ty->getNumElements() >= getFormatNumRegs(BLGP),
        "invalid vector type for format", &Call, Src1, BLGPOp);
  break;
}
63936442
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
63946443
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
63956444
Value *V = Call.getArgOperand(0);

0 commit comments

Comments
 (0)