Skip to content

Commit c479f09

Browse files
committed
Add F4E2M1FN type: conversion codegen
1 parent 70ca820 commit c479f09

File tree

7 files changed

+383
-67
lines changed

7 files changed

+383
-67
lines changed

xla/backends/gpu/codegen/transforms/expand_float_ops.cc

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ int GetExponentBias(mlir::FloatType ty) {
166166
return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
167167
}
168168

169+
bool IsFNUZ(mlir::FloatType ty) {
170+
return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() ||
171+
ty.isFloat8E5M2FNUZ();
172+
}
173+
169174
Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
170175
auto ty = mlir::cast<mlir::FloatType>(value.getType());
171176
if (mlir::LLVM::isCompatibleOuterType(ty)) {
@@ -175,7 +180,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
175180
return b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, value, inf);
176181
}
177182

178-
assert(ty.getIntOrFloatBitWidth() == 8);
183+
assert(ty.getIntOrFloatBitWidth() <= 8);
179184
// F8E5M2, F8E4M3, F8E3M4 are the only 8-bit floats with infinities.
180185
if (ty.isFloat8E5M2()) {
181186
Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
@@ -196,6 +201,9 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
196201
if (mlir::LLVM::isCompatibleOuterType(ty)) {
197202
return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value);
198203
}
204+
if (ty.isFloat4E2M1FN()) {
205+
return b.create<ma::ConstantIntOp>(false, b.getI1Type());
206+
}
199207

200208
assert(ty.getIntOrFloatBitWidth() == 8);
201209
Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
@@ -281,7 +289,7 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
281289
auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth());
282290

283291
mlir::IntegerType wide_int_ty;
284-
if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) {
292+
if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) {
285293
wide_int_ty = b.getI16Type();
286294
} else {
287295
wide_int_ty = b.getIntegerType(
@@ -300,21 +308,20 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
300308
int64_t exp_offset = to_bias - from_bias;
301309
int digit_shift = to_mantissa - from_mantissa;
302310

303-
Val from_bits{
304-
b.create<ma::BitcastOp>(
305-
b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value),
306-
&b};
311+
int from_width = value.getType().getIntOrFloatBitWidth();
312+
Val from_bits{b.create<ma::BitcastOp>(b.getIntegerType(from_width), value),
313+
&b};
314+
if (from_width < 8) {
315+
from_bits = convert_int(b.getIntegerType(8), from_bits);
316+
}
307317

308318
auto cst = [&](mlir::Type ty, int64_t n) -> Val {
309319
return {b.create<ma::ConstantIntOp>(n, ty), &b};
310320
};
311321

312322
// Shift bits to destination type, without sign bit.
313-
Val from_sign_bit =
314-
from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0;
315-
316-
from_bits =
317-
from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1);
323+
Val from_sign_bit = from_bits.shrui(from_width - 1) != 0;
324+
from_bits = from_bits & ((1ULL << (from_width - 1)) - 1);
318325

319326
Value result_is_inf = IsInf(value, b);
320327
Value input_is_nan = IsNaN(value, b);
@@ -327,6 +334,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
327334
Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics()));
328335
Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics()));
329336

337+
// MX float types have neither infinities nor NaNs.
338+
if (to_ty.isFloat4E2M1FN()) {
339+
to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics()));
340+
to_nan = to_zero | 0x8;
341+
}
342+
330343
auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
331344
assert(bits.value.getType() == roundoff.value.getType());
332345
// Round to nearest even by adding a bias term.
@@ -394,10 +407,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
394407
Val bits = convert_int(wide_int_ty, from_bits);
395408

396409
// Determine exponent in target type.
397-
Value normalization_factor =
398-
convert_int(i32_ty,
399-
b.create<mlir::math::CountLeadingZerosOp>(from_bits)) -
400-
(from_int_ty.getWidth() - from_mantissa - 1);
410+
Value clz = convert_int(
411+
i32_ty, b.create<mlir::math::CountLeadingZerosOp>(from_bits));
412+
Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz;
413+
Value normalization_factor = cst(i32_ty, from_mantissa) - msb;
401414

402415
Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor;
403416
// If the result is subnormal, adjust the subnormal bits to account for
@@ -469,18 +482,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
469482
result);
470483
}
471484

472-
// Handle types with no unsigned zero.
473-
auto is_nuz = [](mlir::FloatType ty) {
474-
return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() ||
475-
ty.isFloat8E5M2FNUZ();
476-
};
477-
478-
if (is_nuz(to_ty)) {
485+
if (IsFNUZ(to_ty)) {
479486
// Clear the sign bit if the result is zero (the output has no negative
480-
// zero).
481-
Val result_is_non_zero = Val{result, &b} != 0;
487+
// zero). Handle the edge case when the input is zero and the result is not.
488+
Val result_is_non_zero =
489+
(digit_shift > 0 ? from_bits : Val{result, &b}) != 0;
482490
from_sign_bit = from_sign_bit & result_is_non_zero;
483-
} else if (is_nuz(from_ty)) {
491+
} else if (IsFNUZ(from_ty)) {
484492
// Clear the sign bit if the input is NaN (it's positive but encoded as
485493
// negative 0).
486494
from_sign_bit = from_sign_bit ^ input_is_nan;
@@ -506,8 +514,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern<ma::TruncFOp> {
506514
using FloatValue = mlir::TypedValue<mlir::FloatType>;
507515
auto src = mlir::cast<FloatValue>(op.getOperand());
508516
auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
509-
if (dst_ty.getWidth() != 8) {
510-
return rewriter.notifyMatchFailure(op, "not an 8 bit truncf");
517+
if (dst_ty.getWidth() > 8) {
518+
return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf");
511519
}
512520

513521
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -524,8 +532,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern<ma::ExtFOp> {
524532
using FloatValue = mlir::TypedValue<mlir::FloatType>;
525533
auto src = mlir::cast<FloatValue>(op.getOperand());
526534
auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
527-
if (src.getType().getWidth() != 8) {
528-
return rewriter.notifyMatchFailure(op, "not an 8 bit extf");
535+
if (src.getType().getWidth() > 8) {
536+
return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf");
529537
}
530538

531539
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -544,25 +552,25 @@ struct RewriteF8Cst : public mlir::OpRewritePattern<ma::CmpFOp> {
544552
auto lhs = mlir::cast<FloatValue>(op.getLhs());
545553
auto rhs = mlir::cast<FloatValue>(op.getRhs());
546554

547-
if (lhs.getType().getWidth() != 8) {
548-
return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf");
555+
if (lhs.getType().getWidth() > 8) {
556+
return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf");
549557
}
550558

551559
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
552560
// Skip the f32 conversion if we're comparing UNE.cst.
553561
llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics());
554562
if (op.getPredicate() == ma::CmpFPredicate::UNE &&
555563
mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) {
556-
Val int_value{b.create<ma::BitcastOp>(rewriter.getI8Type(), lhs), &b};
564+
mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth());
565+
Val int_value{b.create<ma::BitcastOp>(int_ty, lhs), &b};
557566
int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue();
558567
// If we're comparing to +-0, compare the absolute values.
559-
if (rhs_cst.isZero() &&
560-
(lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() ||
561-
lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) {
562-
int_value = int_value & 0x7f;
563-
constant &= 0x7f;
568+
if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) {
569+
int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1;
570+
int_value = int_value & mask;
571+
constant &= mask;
564572
}
565-
auto cst = b.create<ma::ConstantIntOp>(constant, rewriter.getI8Type());
573+
auto cst = b.create<ma::ConstantIntOp>(constant, int_ty);
566574
rewriter.replaceOpWithNewOp<ma::CmpIOp>(op, ma::CmpIPredicate::ne,
567575
int_value, cst);
568576
return mlir::success();
@@ -586,18 +594,15 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
586594
auto src = mlir::cast<FloatValue>(op.getOperand());
587595
// LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16.
588596
// Once that's removed, remove the code for BF16 here.
589-
if (src.getType().getWidth() != 8 && !src.getType().isBF16()) {
590-
return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf");
597+
if (src.getType().getWidth() > 8 && !src.getType().isBF16()) {
598+
return rewriter.notifyMatchFailure(op,
599+
"not an f8 (or less) or bf16 absf");
591600
}
592601
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
593602
mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
594603
Val value{b.create<ma::BitcastOp>(i_ty, src), &b};
595-
if (src.getType().getWidth() == 8) {
596-
value = value & 0x7f;
597-
} else {
598-
CHECK(src.getType().isBF16());
599-
value = value & 0x7fff;
600-
}
604+
int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1;
605+
value = value & mask;
601606
rewriter.replaceOpWithNewOp<ma::BitcastOp>(op, src.getType(), value);
602607
return mlir::success();
603608
}
@@ -609,8 +614,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern<Op> {
609614

610615
mlir::LogicalResult matchAndRewrite(
611616
Op op, mlir::PatternRewriter& rewriter) const override {
612-
if (op.getType().getIntOrFloatBitWidth() != 8) {
613-
return rewriter.notifyMatchFailure(op, "not an f8 itofp");
617+
if (op.getType().getIntOrFloatBitWidth() > 8) {
618+
return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp");
614619
}
615620
Value to_float =
616621
rewriter.create<Op>(op.getLoc(), rewriter.getF32Type(), op.getIn());
@@ -625,8 +630,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern<Op> {
625630

626631
mlir::LogicalResult matchAndRewrite(
627632
Op op, mlir::PatternRewriter& rewriter) const override {
628-
if (op.getIn().getType().getIntOrFloatBitWidth() != 8) {
629-
return rewriter.notifyMatchFailure(op, "not an f8 fptoi");
633+
if (op.getIn().getType().getIntOrFloatBitWidth() > 8) {
634+
return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi");
630635
}
631636
Value to_f32 = rewriter.create<ma::ExtFOp>(
632637
op.getLoc(), rewriter.getF32Type(), op.getIn());

xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,40 @@ module {
115115
// CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32
116116
// CHECK: arith.truncf %[[EXT]] : f32 to f16
117117
// CHECK-NOT: arith.truncf
118+
119+
// -----
120+
121+
module {
122+
func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 {
123+
%ret = arith.extf %arg0 : f4E2M1FN to f16
124+
return %ret : f16
125+
}
126+
}
127+
128+
// CHECK-LABEL: @f4_to_f16
129+
// CHECK-NOT: arith.extf
130+
131+
// -----
132+
133+
module {
134+
func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN {
135+
%ret = arith.truncf %arg0 : f16 to f4E2M1FN
136+
return %ret : f4E2M1FN
137+
}
138+
}
139+
140+
// CHECK-LABEL: @f16_to_f4
141+
// CHECK-NOT: arith.truncf
142+
143+
// -----
144+
145+
module {
146+
func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN {
147+
%ret = math.absf %arg0 : f4E2M1FN
148+
return %ret : f4E2M1FN
149+
}
150+
}
151+
152+
// CHECK-LABEL: @f4_abs
153+
// CHECK-NOT: math.absf
154+
// CHECK: arith.constant 7 : i4

0 commit comments

Comments
 (0)