Skip to content

Commit f01ee7e

Browse files
committed
AMDGPU: Add basic verification for mfma scale intrinsics
Verify the format is valid and the type is one of the expected i32 vectors. Verify the used vector types at least cover the requirements of the corresponding format operand.
1 parent c26061e commit f01ee7e

File tree

3 files changed

+283
-6
lines changed

3 files changed

+283
-6
lines changed

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,12 +2973,10 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
29732973
// blgp.
29742974
//
29752975
// These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
2976-
// and <4 x i32> for f4 formats. If the format control bits imply a
2977-
// smaller type than used, the high elements will be truncated.
2978-
//
2979-
// If the format control bits imply a larger type than used, the high
2980-
// elements are padded with undef.
2981-
2976+
// and <4 x i32> for f4 formats. It is invalid to use a format that
2977+
// requires more registers than the corresponding vector type (e.g. it
2978+
// is illegal to use <6 x i32> in operand 0 if cbsz specifies an f8
2979+
// format that requires 8 registers).
29822980
class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
29832981
DefaultAttrsIntrinsic<[DestTy],
29842982
[llvm_anyvector_ty, llvm_anyvector_ty, DestTy,

llvm/lib/IR/Verifier.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6383,6 +6383,55 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
63836383
"llvm.amdgcn.s.prefetch.data only supports global or constant memory");
63846384
break;
63856385
}
6386+
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
6387+
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
6388+
Value *Src0 = Call.getArgOperand(0);
6389+
Value *Src1 = Call.getArgOperand(1);
6390+
6391+
uint64_t CBSZ = cast<ConstantInt>(Call.getArgOperand(3))->getZExtValue();
6392+
uint64_t BLGP = cast<ConstantInt>(Call.getArgOperand(4))->getZExtValue();
6393+
Check(CBSZ <= 4, "invalid value for cbsz format", Call,
6394+
Call.getArgOperand(3));
6395+
Check(BLGP <= 4, "invalid value for blgp format", Call,
6396+
Call.getArgOperand(4));
6397+
6398+
// AMDGPU::MFMAScaleFormats values
6399+
auto getFormatNumRegs = [](unsigned FormatVal) {
6400+
switch (FormatVal) {
6401+
case 0:
6402+
case 1:
6403+
return 8u;
6404+
case 2:
6405+
case 3:
6406+
return 6u;
6407+
case 4:
6408+
return 4u;
6409+
default:
6410+
llvm_unreachable("invalid format value");
6411+
}
6412+
};
6413+
6414+
auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
6415+
if (!Ty || !Ty->getElementType()->isIntegerTy(32))
6416+
return false;
6417+
unsigned NumElts = Ty->getNumElements();
6418+
return NumElts == 4 || NumElts == 6 || NumElts == 8;
6419+
};
6420+
6421+
auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
6422+
auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
6423+
Check(isValidSrcASrcBVector(Src0Ty),
6424+
"operand 0 must be 4, 6 or 8 element i32 vector", &Call, Src0);
6425+
Check(isValidSrcASrcBVector(Src1Ty),
6426+
"operand 1 must be 4, 6 or 8 element i32 vector", &Call, Src1);
6427+
6428+
// Permit excess registers for the format.
6429+
Check(Src0Ty->getNumElements() >= getFormatNumRegs(CBSZ),
6430+
"invalid vector type for format", &Call, Src0, Call.getArgOperand(3));
6431+
// blgp is argument 4 (see the BLGP read above); report that operand with Src1.
Check(Src1Ty->getNumElements() >= getFormatNumRegs(BLGP),
6432+
"invalid vector type for format", &Call, Src1, Call.getArgOperand(4));
6433+
break;
6434+
}
63866435
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
63876436
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
63886437
Value *V = Call.getArgOperand(0);

0 commit comments

Comments
 (0)