Skip to content

Commit 7614a51

Browse files
Revert "[mlir][IR] Remove isF...() type API for low-precision FP types (llvm#123326)"
This reverts commit 7a77f14.
1 parent 9b39b61 commit 7614a51

File tree

11 files changed

+95
-75
lines changed

11 files changed

+95
-75
lines changed

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -329,31 +329,31 @@ def F64 : F<64>;
329329
def F80 : F<80>;
330330
def F128 : F<128>;
331331

332-
def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">,
332+
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
333333
BuildableType<"$_builder.getType<BFloat16Type>()">;
334-
def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">,
334+
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
335335
BuildableType<"$_builder.getType<FloatTF32Type>()">;
336-
def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
336+
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
337337
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
338-
def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">,
338+
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
339339
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
340-
def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">,
340+
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
341341
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
342-
def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
342+
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
343343
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
344-
def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
344+
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
345345
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
346-
def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
346+
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
347347
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
348-
def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">,
348+
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
349349
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
350-
def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
350+
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
351351
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
352-
def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
352+
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
353353
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
354-
def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType($_self)">, "f6E3M2FN type">,
354+
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
355355
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
356-
def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
356+
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
357357
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
358358

359359
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,

mlir/include/mlir/IR/Types.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,17 @@ class Type {
125125
// Convenience predicates. This is only for floating point types,
126126
// derived types should use isa/dyn_cast.
127127
bool isIndex() const;
128+
bool isFloat4E2M1FN() const;
129+
bool isFloat6E2M3FN() const;
130+
bool isFloat6E3M2FN() const;
131+
bool isFloat8E5M2() const;
132+
bool isFloat8E4M3() const;
133+
bool isFloat8E4M3FN() const;
134+
bool isFloat8E5M2FNUZ() const;
135+
bool isFloat8E4M3FNUZ() const;
136+
bool isFloat8E4M3B11FNUZ() const;
137+
bool isFloat8E3M4() const;
138+
bool isFloat8E8M0FNU() const;
128139
bool isBF16() const;
129140
bool isF16() const;
130141
bool isTF32() const;

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
9090
}
9191

9292
bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
93-
return llvm::isa<Float4E2M1FNType>(unwrap(type));
93+
return unwrap(type).isFloat4E2M1FN();
9494
}
9595

9696
MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
@@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
102102
}
103103

104104
bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
105-
return llvm::isa<Float6E2M3FNType>(unwrap(type));
105+
return unwrap(type).isFloat6E2M3FN();
106106
}
107107

108108
MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
@@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
114114
}
115115

116116
bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
117-
return llvm::isa<Float6E3M2FNType>(unwrap(type));
117+
return unwrap(type).isFloat6E3M2FN();
118118
}
119119

120120
MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
@@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
126126
}
127127

128128
bool mlirTypeIsAFloat8E5M2(MlirType type) {
129-
return llvm::isa<Float8E5M2Type>(unwrap(type));
129+
return unwrap(type).isFloat8E5M2();
130130
}
131131

132132
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
@@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
138138
}
139139

140140
bool mlirTypeIsAFloat8E4M3(MlirType type) {
141-
return llvm::isa<Float8E4M3Type>(unwrap(type));
141+
return unwrap(type).isFloat8E4M3();
142142
}
143143

144144
MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
@@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
150150
}
151151

152152
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
153-
return llvm::isa<Float8E4M3FNType>(unwrap(type));
153+
return unwrap(type).isFloat8E4M3FN();
154154
}
155155

156156
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
@@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
162162
}
163163

164164
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
165-
return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
165+
return unwrap(type).isFloat8E5M2FNUZ();
166166
}
167167

168168
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
@@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
174174
}
175175

176176
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
177-
return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
177+
return unwrap(type).isFloat8E4M3FNUZ();
178178
}
179179

180180
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
@@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
186186
}
187187

188188
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
189-
return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
189+
return unwrap(type).isFloat8E4M3B11FNUZ();
190190
}
191191

192192
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
@@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
198198
}
199199

200200
bool mlirTypeIsAFloat8E3M4(MlirType type) {
201-
return llvm::isa<Float8E3M4Type>(unwrap(type));
201+
return unwrap(type).isFloat8E3M4();
202202
}
203203

204204
MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
@@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
210210
}
211211

212212
bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
213-
return llvm::isa<Float8E8M0FNUType>(unwrap(type));
213+
return unwrap(type).isFloat8E8M0FNU();
214214
}
215215

216216
MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
@@ -221,19 +221,15 @@ MlirTypeID mlirBFloat16TypeGetTypeID() {
221221
return wrap(BFloat16Type::getTypeID());
222222
}
223223

224-
bool mlirTypeIsABF16(MlirType type) {
225-
return llvm::isa<BFloat16Type>(unwrap(type));
226-
}
224+
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
227225

228226
MlirType mlirBF16TypeGet(MlirContext ctx) {
229227
return wrap(BFloat16Type::get(unwrap(ctx)));
230228
}
231229

232230
MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
233231

234-
bool mlirTypeIsAF16(MlirType type) {
235-
return llvm::isa<Float16Type>(unwrap(type));
236-
}
232+
bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
237233

238234
MlirType mlirF16TypeGet(MlirContext ctx) {
239235
return wrap(Float16Type::get(unwrap(ctx)));
@@ -243,29 +239,23 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() {
243239
return wrap(FloatTF32Type::getTypeID());
244240
}
245241

246-
bool mlirTypeIsATF32(MlirType type) {
247-
return llvm::isa<FloatTF32Type>(unwrap(type));
248-
}
242+
bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
249243

250244
MlirType mlirTF32TypeGet(MlirContext ctx) {
251245
return wrap(FloatTF32Type::get(unwrap(ctx)));
252246
}
253247

254248
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
255249

256-
bool mlirTypeIsAF32(MlirType type) {
257-
return llvm::isa<Float32Type>(unwrap(type));
258-
}
250+
bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
259251

260252
MlirType mlirF32TypeGet(MlirContext ctx) {
261253
return wrap(Float32Type::get(unwrap(ctx)));
262254
}
263255

264256
MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
265257

266-
bool mlirTypeIsAF64(MlirType type) {
267-
return llvm::isa<Float64Type>(unwrap(type));
268-
}
258+
bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
269259

270260
MlirType mlirF64TypeGet(MlirContext ctx) {
271261
return wrap(Float64Type::get(unwrap(ctx)));

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -564,40 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
564564
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
565565
}
566566

567-
if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
568-
chipset >= kGfx940) {
567+
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
569568
// Known to be correct because there are no scalar f8 instructions and
570569
// because a length mismatch will have been caught by the verifier.
571570
Type sourceBElem =
572571
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
573572
if (m == 16 && n == 16 && k == 32 && b == 1) {
574-
if (isa<Float8E5M2FNUZType>(sourceBElem))
573+
if (sourceBElem.isFloat8E5M2FNUZ())
575574
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
576-
if (isa<Float8E4M3FNUZType>(sourceBElem))
575+
if (sourceBElem.isFloat8E4M3FNUZ())
577576
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
578577
}
579578
if (m == 32 && n == 32 && k == 16 && b == 1) {
580-
if (isa<Float8E5M2FNUZType>(sourceBElem))
579+
if (sourceBElem.isFloat8E5M2FNUZ())
581580
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
582-
if (isa<Float8E4M3FNUZType>(sourceBElem))
581+
if (sourceBElem.isFloat8E4M3FNUZ())
583582
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
584583
}
585584
}
586585

587-
if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
588-
chipset >= kGfx940) {
586+
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
589587
Type sourceBElem =
590588
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
591589
if (m == 16 && n == 16 && k == 32 && b == 1) {
592-
if (isa<Float8E5M2FNUZType>(sourceBElem))
590+
if (sourceBElem.isFloat8E5M2FNUZ())
593591
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
594-
if (isa<Float8E4M3FNUZType>(sourceBElem))
592+
if (sourceBElem.isFloat8E4M3FNUZ())
595593
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
596594
}
597595
if (m == 32 && n == 32 && k == 16 && b == 1) {
598-
if (isa<Float8E5M2FNUZType>(sourceBElem))
596+
if (sourceBElem.isFloat8E5M2FNUZ())
599597
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
600-
if (isa<Float8E4M3FNUZType>(sourceBElem))
598+
if (sourceBElem.isFloat8E4M3FNUZ())
601599
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
602600
}
603601
}
@@ -625,9 +623,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
625623
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
626624
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
627625
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
628-
if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
626+
if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
629627
return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
630-
if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
628+
if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
631629
return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
632630
return std::nullopt;
633631
}
@@ -805,10 +803,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
805803
}
806804
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
807805
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
808-
if (isa<Float8E5M2FNUZType>(sourceElemType)) {
806+
if (sourceElemType.isFloat8E5M2FNUZ()) {
809807
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
810808
wordSel);
811-
} else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
809+
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
812810
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
813811
wordSel);
814812
}
@@ -840,10 +838,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
840838
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
841839

842840
Value result;
843-
if (isa<Float8E5M2FNUZType>(resultElemType))
841+
if (resultElemType.isFloat8E5M2FNUZ())
844842
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
845843
existing, wordSel);
846-
else if (isa<Float8E4M3FNUZType>(resultElemType))
844+
else if (resultElemType.isFloat8E4M3FNUZ())
847845
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
848846
existing, wordSel);
849847

@@ -875,10 +873,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
875873
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
876874

877875
Value result;
878-
if (isa<Float8E5M2FNUZType>(resultElemType))
876+
if (resultElemType.isFloat8E5M2FNUZ())
879877
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
880878
existing, byteSel);
881-
else if (isa<Float8E4M3FNUZType>(resultElemType))
879+
else if (resultElemType.isFloat8E4M3FNUZ())
882880
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
883881
existing, byteSel);
884882

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8686
return failure();
8787
inType = inVecType.getElementType();
8888
}
89-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
89+
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
9090
}
9191

9292
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
216216
if (inType && inType.getWidth() <= 8 && saturateFP8)
217217
// Conversion between 8-bit floats is not supported with truncation enabled.
218218
return failure();
219-
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
219+
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
220220
}
221221

222222
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,11 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
299299
return type;
300300

301301
// F4, F6, F8 types are converted to integer types with the same bit width.
302-
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
303-
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
304-
Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
305-
Float8E8M0FNUType>(type))
302+
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
303+
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
304+
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
305+
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
306+
type.isFloat8E8M0FNU())
306307
return IntegerType::get(&getContext(), type.getWidth());
307308

308309
// Other floating-point types: A custom type conversion rule must be

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering
12541254
wgmmaK = 8;
12551255
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
12561256
wgmmaK = 16;
1257-
} else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1258-
inputElemType.isInteger(16)) {
1257+
} else if (inputElemType.isFloat8E4M3FN() ||
1258+
inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
12591259
wgmmaK = 32;
12601260
} else if (inputElemType.isInteger(1)) {
12611261
wgmmaK = 256;
@@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering
12761276
return NVVM::WGMMATypes::f16;
12771277
if (elemType.isBF16())
12781278
return NVVM::WGMMATypes::bf16;
1279-
if (isa<Float8E4M3FNType>(elemType))
1279+
if (elemType.isFloat8E4M3FN())
12801280
return NVVM::WGMMATypes::e4m3;
1281-
if (isa<Float8E5M2Type>(elemType))
1281+
if (elemType.isFloat8E5M2())
12821282
return NVVM::WGMMATypes::e5m2;
12831283
if (elemType.isInteger(1))
12841284
return NVVM::WGMMATypes::b1;

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
272272
}
273273

274274
Type sourceBType = getSourceB().getType();
275-
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
275+
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
276276
int64_t sourceBLen = 1;
277277
Type sourceBElem = sourceBType;
278278
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
279279
sourceBLen = sourceBVector.getNumElements();
280280
sourceBElem = sourceBVector.getElementType();
281281
}
282-
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
282+
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
283283
return emitOpError("expected both source operands to have f8 elements");
284284
if (sourceLen != sourceBLen)
285285
return emitOpError(

0 commit comments

Comments
 (0)