
Commit 75379aa

Replace isF...() LLVM API calls with the corresponding isa<...>()
The isF...() methods have been removed in the main LLVM branch: llvm/llvm-project#123326
1 parent 3f44826 commit 75379aa
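The change applied throughout is mechanical: each Type::isFloat8...() member predicate becomes an isa<...Type>() check against the corresponding builtin MLIR float type class. A minimal before/after sketch of the pattern, assuming the free function isAnyFp8 as an illustrative helper (it is not part of this commit):

    #include "mlir/IR/BuiltinTypes.h"

    using namespace mlir;

    // Old style: member predicates, removed upstream by llvm/llvm-project#123326.
    //   bool isAnyFp8(Type t) { return t.isFloat8E5M2() || t.isFloat8E4M3FN(); }

    // New style: query the concrete FP8 type classes with isa<>.
    bool isAnyFp8(Type t) {
      return isa<Float8E5M2Type>(t) || isa<Float8E4M3FNType>(t);
    }

The same substitution covers every FP8 variant touched below (Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2Type, Float8E5M2FNUZType).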

15 files changed: +70 additions, −58 deletions


include/triton/Conversion/MLIRTypes.h

Lines changed: 6 additions & 6 deletions
@@ -28,15 +28,15 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+         type.isBF16() || isa<Float8E4M3B11FNUZType>(type) ||
+         isa<Float8E4M3FNType>(type) || isa<Float8E4M3FNUZType>(type) ||
+         isa<Float8E5M2Type>(type) || isa<Float8E5M2FNUZType>(type);
 }
 
 inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+  return isa<Float8E4M3B11FNUZType>(type) || isa<Float8E4M3FNType>(type) ||
+         isa<Float8E4M3FNUZType>(type) || isa<Float8E5M2Type>(type) ||
+         isa<Float8E5M2FNUZType>(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 4 deletions
@@ -756,14 +756,14 @@ bool supportMMA(triton::DotOp op, int version) {
     return false;
   if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
         retShapePerCTA[rank - 1] % 8 == 0 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+        (isa<Float8E5M2Type>(aElemTy) || isa<Float8E4M3FNType>(aElemTy) ||
         aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
         aElemTy.isF32()))) {
     return false;
   }
   // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
   if (op.getMaxNumImpreciseAcc() < 32 &&
-      (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+      (isa<Float8E5M2Type>(aElemTy) || isa<Float8E4M3FNType>(aElemTy)) &&
       cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
     return false;
   }
@@ -784,8 +784,9 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 = isa<Float8E5M2Type>(elemTy) || isa<Float8E4M3FNType>(elemTy) ||
+               isa<Float8E5M2FNUZType>(elemTy) ||
+               isa<Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 1 deletion
@@ -632,7 +632,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 =
+          isa<Float8E5M2Type>(AElType) || isa<Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 3 deletions
@@ -45,9 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
-      eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
-      eltType.isF32()) {
+  if (isa<Float8E5M2Type>(eltType) || isa<Float8E4M3FNType>(eltType) ||
+      isa<Float8E4M3FNUZType>(eltType) || eltType.isF16() ||
+      eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                    168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
                    80, 72, 64, 56, 48, 40, 32, 24, 16, 8});

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<triton::gpu::TensorOrMemDesc>(a.getType());
   auto aElTy = cast<triton::gpu::TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
-               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool isFP8 = isa<Float8E5M2Type>(aElTy) || isa<Float8E4M3FNType>(aElTy) ||
+               isa<Float8E5M2FNUZType>(aElTy) || isa<Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<triton::gpu::TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 8 additions & 7 deletions
@@ -1019,17 +1019,18 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() ||
-        srcElementType.isFloat8E4M3FNUZ() ||
-        dstElementType.isFloat8E4M3FNUZ() ||
-        srcElementType.isFloat8E5M2FNUZ() ||
-        dstElementType.isFloat8E5M2FNUZ()) {
+    if (isa<Float8E4M3FNType>(srcElementType) ||
+        isa<Float8E4M3FNType>(dstElementType) ||
+        isa<Float8E4M3FNUZType>(srcElementType) ||
+        isa<Float8E4M3FNUZType>(dstElementType) ||
+        isa<Float8E5M2FNUZType>(srcElementType) ||
+        isa<Float8E5M2FNUZType>(dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
         srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 &&
-                                    (dstElementType.isFloat8E4M3FNUZ() ||
-                                     dstElementType.isFloat8E5M2FNUZ()));
+                                    (isa<Float8E4M3FNUZType>(dstElementType) ||
+                                     isa<Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 2 additions & 1 deletion
@@ -416,7 +416,8 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
+    bool isFP8 =
+        isa<Float8E5M2FNUZType>(aElemTy) || isa<Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),

third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp

Lines changed: 9 additions & 5 deletions
@@ -20,19 +20,23 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
   if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
     return MfmaTypeId::I8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (isa<Float8E4M3FNUZType>(dataTypeA) &&
+      isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (isa<Float8E4M3FNUZType>(dataTypeA) &&
+      isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (isa<Float8E5M2FNUZType>(dataTypeA) &&
+      isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (isa<Float8E5M2FNUZType>(dataTypeA) &&
+      isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) {
+  if (isa<Float8E5M2Type>(dataTypeA) && isa<Float8E5M2Type>(dataTypeB)) {
     return MfmaTypeId::Fp16TyId;
   }
   llvm_unreachable("Unsupported input argument type.");

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 10 additions & 10 deletions
@@ -125,9 +125,9 @@ DPASAnalysis::getDPASType(OpTy op) {
       if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
         return DPASEngineType::FP32_FP32_TF32_TF32;
       // For FP8XFP8->FP32, upcast to FP16
-      if (aElemTy.isFloat8E5M2())
+      if (isa<Float8E5M2Type>(aElemTy))
        return DPASEngineType::FP32_FP32_FP16_FP16;
-      if (aElemTy.isFloat8E4M3FN())
+      if (isa<Float8E4M3FNType>(aElemTy))
        return DPASEngineType::FP32_FP32_FP16_FP16;
     } else if (dElemTy.isF16()) {
       if (aElemTy.isF16())
@@ -148,35 +148,35 @@ DPASAnalysis::getDPASType(OpTy op) {
     if (isa<FloatType>(dElemTy)) {
       if (dElemTy.isF32()) {
         if (aElemTy.isBF16() &&
-            (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+            (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
           return DPASEngineType::FP32_FP32_BF16_FP8;
         // 2 E2M1 are packed into 1 int8
         if (aElemTy.isBF16() && bElemTy.isInteger(8))
           return DPASEngineType::FP32_FP32_BF16_FP4;
-        if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
+        if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
            bElemTy.isBF16())
          return DPASEngineType::FP32_FP32_FP8_BF16;
        if (aElemTy.isF16() &&
-            (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+            (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
          return DPASEngineType::FP32_FP32_FP16_FP8;
        // 2 E2M1 are packed into 1 int8
        if (aElemTy.isF16() && bElemTy.isInteger(8))
          return DPASEngineType::FP32_FP32_FP16_FP4;
-        if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
+        if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
            bElemTy.isF16())
          return DPASEngineType::FP32_FP32_FP8_FP16;
-        if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-            (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+        if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
+            (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
          return DPASEngineType::FP32_FP32_FP8_FP8;
-        if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
+        if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
            bElemTy.isInteger(8))
          return DPASEngineType::FP32_FP32_FP8_FP4;
        if (aElemTy.isInteger(8) && bElemTy.isBF16())
          return DPASEngineType::FP32_FP32_FP4_BF16;
        if (aElemTy.isInteger(8) && bElemTy.isF16())
          return DPASEngineType::FP32_FP32_FP4_FP16;
        if (aElemTy.isInteger(8) &&
-            (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+            (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
          return DPASEngineType::FP32_FP32_FP4_FP8;
      }
    }

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 2 additions & 1 deletion
@@ -405,7 +405,8 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
   assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
 
   unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
-  if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
+  if (llvm::isa<Float8E5M2Type>(elemType) ||
+      llvm::isa<Float8E4M3FNType>(elemType))
     dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.
 
   return DPASCapability::opsChanBitWidths / dpasElemBitWidths;

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -960,7 +960,8 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (isa<Float8E5M2Type>(dstElementType) ||
+        isa<Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for conversions to fp8");
 
third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 4 additions & 3 deletions
@@ -132,8 +132,8 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
         oldAType.getElementType().getIntOrFloatBitWidth();
 
     // We are upcasting FP8 to FP16
-    if (oldAType.getElementType().isFloat8E5M2() ||
-        oldAType.getElementType().isFloat8E4M3FN())
+    if (isa<Float8E5M2Type>(oldAType.getElementType()) ||
+        isa<Float8E4M3FNType>(oldAType.getElementType()))
       dpasElemBitWidths = 2 * dpasElemBitWidths;
 
     // Enlarge the repCluster size to use the large 2D load for A and B
@@ -488,7 +488,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
 
     Type promoteType;
     if (dpasLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 =
+          isa<Float8E5M2Type>(AElType) || isa<Float8E4M3FNType>(AElType);
       // fp8 is not natively supported by the the DPAS instruction, promote it
       // to fp16.
       if (!isNativeFP8)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 8 additions & 8 deletions
@@ -299,17 +299,17 @@ TensorCoreType getMmaType(triton::DotOp op) {
     return TensorCoreType::FP32_FP16_FP16_FP32;
   if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
     return TensorCoreType::FP32_BF16_BF16_FP32;
-  if (aTy.getElementType().isFloat8E5M2() &&
-      bTy.getElementType().isFloat8E5M2())
+  if (isa<Float8E5M2Type>(aTy.getElementType()) &&
+      isa<Float8E5M2Type>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
-  if (aTy.getElementType().isFloat8E5M2() &&
-      bTy.getElementType().isFloat8E4M3FN())
+  if (isa<Float8E5M2Type>(aTy.getElementType()) &&
+      isa<Float8E4M3FNType>(bTy.getElementType()))
    return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
-  if (aTy.getElementType().isFloat8E4M3FN() &&
-      bTy.getElementType().isFloat8E5M2())
+  if (isa<Float8E4M3FNType>(aTy.getElementType()) &&
+      isa<Float8E5M2Type>(bTy.getElementType()))
    return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
-  if (aTy.getElementType().isFloat8E4M3FN() &&
-      bTy.getElementType().isFloat8E4M3FN())
+  if (isa<Float8E4M3FNType>(aTy.getElementType()) &&
+      isa<Float8E4M3FNType>(bTy.getElementType()))
    return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
   if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
       op.getInputPrecision() == InputPrecision::TF32)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp

Lines changed: 2 additions & 2 deletions
@@ -59,9 +59,9 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
     return triton::nvgpu::WGMMAEltType::tf32;
   } else if (aTy.isInteger(8)) {
     return triton::nvgpu::WGMMAEltType::s8;
-  } else if (aTy.isFloat8E5M2()) {
+  } else if (isa<Float8E5M2Type>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e5m2;
-  } else if (aTy.isFloat8E4M3FN()) {
+  } else if (isa<Float8E4M3FNType>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e4m3;
   } else {
     llvm::report_fatal_error("Unsupported mma operand type found");

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 5 additions & 4 deletions
@@ -467,7 +467,7 @@ struct FpToFpOpConversion
       llvm::report_fatal_error("Unsupported rounding mode for conversion.");
     }
     if (computeCapability < 89 &&
-        (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
+        (isa<Float8E4M3FNType>(srcTy) || isa<Float8E4M3FNType>(dstTy))) {
       llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
                       "compute capability >= 89"
                    << "\n";
@@ -489,7 +489,8 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (isa<Float8E5M2Type>(dstElementType) ||
+        isa<Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
 
@@ -526,8 +527,8 @@ struct FpToFpOpConversion
 
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
-        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
-                                       dstElementType.isFloat8E5M2())) ||
+        (!(computeCapability >= 90 && (isa<Float8E4M3FNType>(dstElementType) ||
+                                       isa<Float8E5M2Type>(dstElementType))) ||
         roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
