@@ -166,6 +166,11 @@ int GetExponentBias(mlir::FloatType ty) {
166
166
return 1 - llvm::APFloat::semanticsMinExponent (ty.getFloatSemantics ());
167
167
}
168
168
169
+ bool IsFNUZ (mlir::FloatType ty) {
170
+ return ty.isFloat8E4M3B11FNUZ () || ty.isFloat8E4M3FNUZ () ||
171
+ ty.isFloat8E5M2FNUZ ();
172
+ }
173
+
169
174
Value IsInf (Value value, mlir::ImplicitLocOpBuilder& b) {
170
175
auto ty = mlir::cast<mlir::FloatType>(value.getType ());
171
176
if (mlir::LLVM::isCompatibleOuterType (ty)) {
@@ -175,7 +180,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
175
180
return b.create <ma::CmpFOp>(ma::CmpFPredicate::OEQ, value, inf);
176
181
}
177
182
178
- assert (ty.getIntOrFloatBitWidth () = = 8 );
183
+ assert (ty.getIntOrFloatBitWidth () < = 8 );
179
184
// F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities.
180
185
if (ty.isFloat8E5M2 ()) {
181
186
Val bits{b.create <ma::BitcastOp>(b.getI8Type (), value), &b};
@@ -196,6 +201,9 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
196
201
if (mlir::LLVM::isCompatibleOuterType (ty)) {
197
202
return b.create <ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value);
198
203
}
204
+ if (ty.isFloat4E2M1FN ()) {
205
+ return b.create <ma::ConstantIntOp>(false , b.getI1Type ());
206
+ }
199
207
200
208
assert (ty.getIntOrFloatBitWidth () == 8 );
201
209
Val bits{b.create <ma::BitcastOp>(b.getI8Type (), value), &b};
@@ -281,7 +289,7 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
281
289
auto to_int_ty = b.getIntegerType (to_ty.getIntOrFloatBitWidth ());
282
290
283
291
mlir::IntegerType wide_int_ty;
284
- if (from_ty.getWidth () == 8 && to_ty.getWidth () = = 8 ) {
292
+ if (from_ty.getWidth () <= 8 && to_ty.getWidth () < = 8 ) {
285
293
wide_int_ty = b.getI16Type ();
286
294
} else {
287
295
wide_int_ty = b.getIntegerType (
@@ -300,21 +308,20 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
300
308
int64_t exp_offset = to_bias - from_bias;
301
309
int digit_shift = to_mantissa - from_mantissa;
302
310
303
- Val from_bits{
304
- b.create <ma::BitcastOp>(
305
- b.getIntegerType (value.getType ().getIntOrFloatBitWidth ()), value),
306
- &b};
311
+ int from_width = value.getType ().getIntOrFloatBitWidth ();
312
+ Val from_bits{b.create <ma::BitcastOp>(b.getIntegerType (from_width), value),
313
+ &b};
314
+ if (from_width < 8 ) {
315
+ from_bits = convert_int (b.getIntegerType (8 ), from_bits);
316
+ }
307
317
308
318
auto cst = [&](mlir::Type ty, int64_t n) -> Val {
309
319
return {b.create <ma::ConstantIntOp>(n, ty), &b};
310
320
};
311
321
312
322
// Shift bits to destination type, without sign bit.
313
- Val from_sign_bit =
314
- from_bits.shrui (value.getType ().getIntOrFloatBitWidth () - 1 ) != 0 ;
315
-
316
- from_bits =
317
- from_bits & ((1ULL << (value.getType ().getIntOrFloatBitWidth () - 1 )) - 1 );
323
+ Val from_sign_bit = from_bits.shrui (from_width - 1 ) != 0 ;
324
+ from_bits = from_bits & ((1ULL << (from_width - 1 )) - 1 );
318
325
319
326
Value result_is_inf = IsInf (value, b);
320
327
Value input_is_nan = IsNaN (value, b);
@@ -327,6 +334,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
327
334
Value to_nan = cst_bits (llvm::APFloat::getNaN (to_ty.getFloatSemantics ()));
328
335
Val to_zero = cst_bits (llvm::APFloat::getZero (to_ty.getFloatSemantics ()));
329
336
337
+ // MX float types have neither infinities nor NaNs.
338
+ if (to_ty.isFloat4E2M1FN ()) {
339
+ to_inf = cst_bits (llvm::APFloat::getLargest (to_ty.getFloatSemantics ()));
340
+ to_nan = to_zero | 0x8 ;
341
+ }
342
+
330
343
auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
331
344
assert (bits.value .getType () == roundoff.value .getType ());
332
345
// Round to nearest even by adding a bias term.
@@ -394,10 +407,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
394
407
Val bits = convert_int (wide_int_ty, from_bits);
395
408
396
409
// Determine exponent in target type.
397
- Value normalization_factor =
398
- convert_int ( i32_ty,
399
- b. create <mlir::math::CountLeadingZerosOp>(from_bits) ) -
400
- (from_int_ty. getWidth () - from_mantissa - 1 ) ;
410
+ Value clz = convert_int (
411
+ i32_ty, b. create <mlir::math::CountLeadingZerosOp>(from_bits));
412
+ Value msb = cst (i32_ty, std::max (from_width, 8 ) - 1 ) - clz;
413
+ Value normalization_factor = cst (i32_ty, from_mantissa) - msb ;
401
414
402
415
Val biased_exponent = cst (i32_ty, exp_offset + 1 ) - normalization_factor;
403
416
// If the result is subnormal, adjust the subnormal bits to account for
@@ -469,18 +482,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
469
482
result);
470
483
}
471
484
472
- // Handle types with no unsigned zero.
473
- auto is_nuz = [](mlir::FloatType ty) {
474
- return ty.isFloat8E4M3B11FNUZ () || ty.isFloat8E4M3FNUZ () ||
475
- ty.isFloat8E5M2FNUZ ();
476
- };
477
-
478
- if (is_nuz (to_ty)) {
485
+ if (IsFNUZ (to_ty)) {
479
486
// Clear the sign bit if the result is zero (the output has no negative
480
- // zero).
481
- Val result_is_non_zero = Val{result, &b} != 0 ;
487
+ // zero). Handle the edge case when the input is zero and the result is not.
488
+ Val result_is_non_zero =
489
+ (digit_shift > 0 ? from_bits : Val{result, &b}) != 0 ;
482
490
from_sign_bit = from_sign_bit & result_is_non_zero;
483
- } else if (is_nuz (from_ty)) {
491
+ } else if (IsFNUZ (from_ty)) {
484
492
// Clear the sign bit if the input is NaN (it's positive but encoded as
485
493
// negative 0).
486
494
from_sign_bit = from_sign_bit ^ input_is_nan;
@@ -506,8 +514,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern<ma::TruncFOp> {
506
514
using FloatValue = mlir::TypedValue<mlir::FloatType>;
507
515
auto src = mlir::cast<FloatValue>(op.getOperand ());
508
516
auto dst_ty = mlir::cast<mlir::FloatType>(op.getType ());
509
- if (dst_ty.getWidth () != 8 ) {
510
- return rewriter.notifyMatchFailure (op, " not an 8 bit truncf" );
517
+ if (dst_ty.getWidth () > 8 ) {
518
+ return rewriter.notifyMatchFailure (op, " not an 8 bit (or less) truncf" );
511
519
}
512
520
513
521
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
@@ -524,8 +532,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern<ma::ExtFOp> {
524
532
using FloatValue = mlir::TypedValue<mlir::FloatType>;
525
533
auto src = mlir::cast<FloatValue>(op.getOperand ());
526
534
auto dst_ty = mlir::cast<mlir::FloatType>(op.getType ());
527
- if (src.getType ().getWidth () != 8 ) {
528
- return rewriter.notifyMatchFailure (op, " not an 8 bit extf" );
535
+ if (src.getType ().getWidth () > 8 ) {
536
+ return rewriter.notifyMatchFailure (op, " not an 8 bit (or less) extf" );
529
537
}
530
538
531
539
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
@@ -544,25 +552,25 @@ struct RewriteF8Cst : public mlir::OpRewritePattern<ma::CmpFOp> {
544
552
auto lhs = mlir::cast<FloatValue>(op.getLhs ());
545
553
auto rhs = mlir::cast<FloatValue>(op.getRhs ());
546
554
547
- if (lhs.getType ().getWidth () != 8 ) {
548
- return rewriter.notifyMatchFailure (op, " not an 8 bit cmpf" );
555
+ if (lhs.getType ().getWidth () > 8 ) {
556
+ return rewriter.notifyMatchFailure (op, " not an 8 bit (or less) cmpf" );
549
557
}
550
558
551
559
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
552
560
// Skip the f32 conversion if we're comparing UNE.cst.
553
561
llvm::APFloat rhs_cst (rhs.getType ().getFloatSemantics ());
554
562
if (op.getPredicate () == ma::CmpFPredicate::UNE &&
555
563
mlir::matchPattern (rhs, mlir::m_ConstantFloat (&rhs_cst))) {
556
- Val int_value{b.create <ma::BitcastOp>(rewriter.getI8Type (), lhs), &b};
564
+ mlir::Type int_ty = rewriter.getIntegerType (lhs.getType ().getWidth ());
565
+ Val int_value{b.create <ma::BitcastOp>(int_ty, lhs), &b};
557
566
int64_t constant = rhs_cst.bitcastToAPInt ().getZExtValue ();
558
567
// If we're comparing to +-0, compare the absolute values.
559
- if (rhs_cst.isZero () &&
560
- (lhs.getType ().isFloat8E3M4 () || lhs.getType ().isFloat8E4M3 () ||
561
- lhs.getType ().isFloat8E4M3FN () || lhs.getType ().isFloat8E5M2 ())) {
562
- int_value = int_value & 0x7f ;
563
- constant &= 0x7f ;
568
+ if (rhs_cst.isZero () && !IsFNUZ (lhs.getType ())) {
569
+ int64_t mask = (1 << (lhs.getType ().getWidth () - 1 )) - 1 ;
570
+ int_value = int_value & mask;
571
+ constant &= mask;
564
572
}
565
- auto cst = b.create <ma::ConstantIntOp>(constant, rewriter. getI8Type () );
573
+ auto cst = b.create <ma::ConstantIntOp>(constant, int_ty );
566
574
rewriter.replaceOpWithNewOp <ma::CmpIOp>(op, ma::CmpIPredicate::ne,
567
575
int_value, cst);
568
576
return mlir::success ();
@@ -586,18 +594,15 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
586
594
auto src = mlir::cast<FloatValue>(op.getOperand ());
587
595
// LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16.
588
596
// Once that's removed, remove the code for BF16 here.
589
- if (src.getType ().getWidth () != 8 && !src.getType ().isBF16 ()) {
590
- return rewriter.notifyMatchFailure (op, " not an f8 or bf16 absf" );
597
+ if (src.getType ().getWidth () > 8 && !src.getType ().isBF16 ()) {
598
+ return rewriter.notifyMatchFailure (op,
599
+ " not an f8 (or less) or bf16 absf" );
591
600
}
592
601
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
593
602
mlir::Type i_ty = rewriter.getIntegerType (src.getType ().getWidth ());
594
603
Val value{b.create <ma::BitcastOp>(i_ty, src), &b};
595
- if (src.getType ().getWidth () == 8 ) {
596
- value = value & 0x7f ;
597
- } else {
598
- CHECK (src.getType ().isBF16 ());
599
- value = value & 0x7fff ;
600
- }
604
+ int64_t mask = (1ull << (src.getType ().getWidth () - 1 )) - 1 ;
605
+ value = value & mask;
601
606
rewriter.replaceOpWithNewOp <ma::BitcastOp>(op, src.getType (), value);
602
607
return mlir::success ();
603
608
}
@@ -609,8 +614,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern<Op> {
609
614
610
615
mlir::LogicalResult matchAndRewrite (
611
616
Op op, mlir::PatternRewriter& rewriter) const override {
612
- if (op.getType ().getIntOrFloatBitWidth () != 8 ) {
613
- return rewriter.notifyMatchFailure (op, " not an f8 itofp" );
617
+ if (op.getType ().getIntOrFloatBitWidth () > 8 ) {
618
+ return rewriter.notifyMatchFailure (op, " not an f8 (or less) itofp" );
614
619
}
615
620
Value to_float =
616
621
rewriter.create <Op>(op.getLoc (), rewriter.getF32Type (), op.getIn ());
@@ -625,8 +630,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern<Op> {
625
630
626
631
mlir::LogicalResult matchAndRewrite (
627
632
Op op, mlir::PatternRewriter& rewriter) const override {
628
- if (op.getIn ().getType ().getIntOrFloatBitWidth () != 8 ) {
629
- return rewriter.notifyMatchFailure (op, " not an f8 fptoi" );
633
+ if (op.getIn ().getType ().getIntOrFloatBitWidth () > 8 ) {
634
+ return rewriter.notifyMatchFailure (op, " not an f8 (or less) fptoi" );
630
635
}
631
636
Value to_f32 = rewriter.create <ma::ExtFOp>(
632
637
op.getLoc (), rewriter.getF32Type (), op.getIn ());
0 commit comments