Skip to content

Commit 9e8c7bc

Browse files
committed
Add F8E8M0FNU type
1 parent d7d5af7 commit 9e8c7bc

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

71 files changed

+810
-140
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
diff --git a/include/float8.h b/include/float8.h
2+
index 51d91fd..c53eb45 100644
3+
--- a/include/float8.h
4+
+++ b/include/float8.h
5+
@@ -1021,11 +1021,11 @@ struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base {
6+
static inline constexpr const int max_digits10 =
7+
MaxDigits10FromDigits(digits);
8+
// 2**-127 smallest valid normalized value..
9+
- static inline constexpr const int min_exponent = -127 + 1;
10+
+ static inline constexpr const int min_exponent = -kExponentBias + 1;
11+
static inline constexpr const int min_exponent10 =
12+
MinExponent10FromMinExponent(min_exponent);
13+
// 128 encoding using for NaN
14+
- static inline constexpr const int max_exponent = 127;
15+
+ static inline constexpr const int max_exponent = kExponentBias + 1;
16+
static inline constexpr const int max_exponent10 =
17+
MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
18+
static inline constexpr const bool is_iec559 = false;
19+
@@ -1292,7 +1292,8 @@ struct Traits<float8_e8m0fnu> : public TraitsBase<float8_e8m0fnu> {
20+
};
21+
22+
template <typename Bits>
23+
-constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) {
24+
+constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff,
25+
+ bool use_implicit_bit) {
26+
// Round to nearest even by adding a bias term.
27+
// Consider a bit pattern
28+
// FFF...FLRTT...T,
29+
@@ -1301,9 +1302,12 @@ constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) {
30+
// - L is 1, R is 1, OR
31+
// - L is 0, R is 1, any T is one.
32+
// We do this by adding L to a bit pattern consisting of all T = 1.
33+
- Bits bias = roundoff == 0
34+
- ? 0
35+
- : ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1;
36+
+ //
37+
+ // When rounding to zero mantissa (E8M0 type), the L bit is implicitly 1 (do
38+
+ // not use the exponent bits for rounding). Add only the R bit in this case.
39+
+ Bits bias = !use_implicit_bit
40+
+ ? ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1
41+
+ : Bits{1} << (roundoff - 1);
42+
return bits + bias;
43+
}
44+
45+
@@ -1443,6 +1447,7 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
46+
}
47+
48+
const int biased_from_exponent = from_bits >> kFromMantissaBits;
49+
+ const bool to_zero_mantissa = kToMantissaBits == 0;
50+
51+
// `To` supports more exponents near zero which means that some subnormal
52+
// values in `From` may become normal.
53+
@@ -1473,11 +1478,14 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
54+
}
55+
56+
// Truncate/round mantissa if necessary.
57+
- if constexpr (kDigitShift > 0) {
58+
+ if constexpr (kDigitShift >= 0) {
59+
bits <<= kDigitShift;
60+
} else {
61+
if constexpr (!kTruncate) {
62+
- bits = RoundBitsToNearestEven(bits, -kDigitShift);
63+
+ // When converting float to e8m0, the bits represent a denormal,
64+
+ // so don't use the implicit mantissa bit for rounding.
65+
+ bits = RoundBitsToNearestEven(
66+
+ bits, -kDigitShift, to_zero_mantissa && kExponentOffset != 0);
67+
}
68+
bits >>= -kDigitShift;
69+
}
70+
@@ -1514,8 +1522,8 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
71+
// otherwise the lower precision bits may already be lost. There
72+
// is an edge-case where rounding to a normalized value would
73+
// normally round down, but for a subnormal, we need to round up.
74+
- rounded_from_bits =
75+
- RoundBitsToNearestEven(rounded_from_bits, exponent_shift);
76+
+ rounded_from_bits = RoundBitsToNearestEven(rounded_from_bits,
77+
+ exponent_shift, false);
78+
}
79+
bits = rounded_from_bits >> exponent_shift;
80+
}
81+
@@ -1532,7 +1540,8 @@ struct ConvertImpl<From, To, kSaturate, kTruncate,
82+
WideBits rounded_from_bits = from_bits;
83+
if constexpr (kDigitShift < 0) {
84+
if constexpr (!kTruncate) {
85+
- rounded_from_bits = RoundBitsToNearestEven(from_bits, -kDigitShift);
86+
+ rounded_from_bits =
87+
+ RoundBitsToNearestEven(from_bits, -kDigitShift, to_zero_mantissa);
88+
}
89+
// Zero-out tail bits.
90+
rounded_from_bits &= ~((WideBits{1} << (-kDigitShift)) - 1);
91+
@@ -1602,7 +1611,7 @@ struct ConvertImpl<Eigen::half, float8_e5m2, kSaturate, kTruncate> {
92+
}
93+
94+
if constexpr (!kTruncate) {
95+
- from_bits = RoundBitsToNearestEven(from_bits, 8);
96+
+ from_bits = RoundBitsToNearestEven(from_bits, 8, false);
97+
// Rounding can cause an overflow to infinity. Clamp to the largest finite
98+
// value if saturation is requested.
99+
if constexpr (kSaturate) {

third_party/tsl/third_party/py/ml_dtypes/workspace.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def repo():
1212
tf_http_archive(
1313
name = "ml_dtypes",
1414
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",
15+
patch_file = ["//third_party/py/ml_dtypes:e8m0.patch"],
1516
link_files = {
1617
"//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel",
1718
"//third_party/py/ml_dtypes:LICENSE": "LICENSE",

third_party/tsl/tsl/platform/ml_dtypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
2929
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
3030
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
3131
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
32+
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;
3233

3334
using int1 = ::ml_dtypes::int1;
3435
using uint1 = ::ml_dtypes::uint1;

xla/array2d_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,20 @@ TEST(Array2dTest, LinspaceF4E2M1FN) {
233233
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
234234
}
235235

236+
TEST(Array2dTest, LinspaceF8E8M0FNU) {
237+
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);
238+
239+
EXPECT_EQ(arr->n1(), 3);
240+
EXPECT_EQ(arr->n2(), 2);
241+
242+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
243+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
244+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
245+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
246+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
247+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
248+
}
249+
236250
TEST(Array2dTest, Stringification) {
237251
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
238252
const std::string expected = R"([[1, 1.5],

xla/backends/gpu/codegen/transforms/expand_float_ops.cc

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ int GetSignificandBits(mlir::FloatType ty) {
163163
}
164164

165165
int GetExponentBias(mlir::FloatType ty) {
166-
return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
166+
return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) -
167+
ty.isFloat8E8M0FNU(); // No zero exponent for E8M0.
167168
}
168169

169170
bool IsFNUZ(mlir::FloatType ty) {
@@ -215,6 +216,8 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
215216
return (bits & 0b0111'1111) == 0b0111'1111;
216217
} else if (ty.isFloat8E3M4()) {
217218
return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000);
219+
} else if (ty.isFloat8E8M0FNU()) {
220+
return bits == 0xFF;
218221
}
219222
return bits == 0x80;
220223
}
@@ -294,6 +297,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
294297
} else {
295298
wide_int_ty = b.getIntegerType(
296299
std::max(from_int_ty.getWidth(), to_int_ty.getWidth()));
300+
// Avoid overflow for bit shifts.
301+
auto may_overflow = [&](mlir::Type a, mlir::Type b) {
302+
return a.isFloat8E8M0FNU() && b.isF16();
303+
};
304+
if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) {
305+
wide_int_ty = b.getI32Type();
306+
}
297307
}
298308
auto convert_int = [&](mlir::Type ty, Value v) -> Val {
299309
if (v.getType() == ty) {
@@ -320,11 +330,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
320330
};
321331

322332
// Shift bits to destination type, without sign bit.
323-
Val from_sign_bit = from_bits.shrui(from_width - 1) != 0;
324-
from_bits = from_bits & ((1ULL << (from_width - 1)) - 1);
325-
326-
Value result_is_inf = IsInf(value, b);
327-
Value input_is_nan = IsNaN(value, b);
333+
Val from_sign_bit;
334+
if (!from_ty.isFloat8E8M0FNU()) {
335+
from_sign_bit = from_bits.shrui(from_width - 1) != 0;
336+
from_bits = from_bits & ((1ULL << (from_width - 1)) - 1);
337+
}
328338

329339
auto cst_bits = [&](llvm::APFloat f) {
330340
return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())),
@@ -338,9 +348,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
338348
if (to_ty.isFloat4E2M1FN()) {
339349
to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics()));
340350
to_nan = to_zero | 0x8;
351+
} else if (to_ty.isFloat8E8M0FNU()) {
352+
to_inf = to_nan;
353+
to_zero = Val{to_nan, &b};
341354
}
342355

343-
auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
356+
auto round_bits_to_nearest_even = [&](Val bits, Val roundoff,
357+
bool use_implicit_bit = false) {
344358
assert(bits.value.getType() == roundoff.value.getType());
345359
// Round to nearest even by adding a bias term.
346360
// Consider a bit pattern
@@ -350,9 +364,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
350364
// - L is 1, R is 1, OR
351365
// - L is 0, R is 1, any T is one.
352366
// We do this by adding L to a bit pattern consisting of all T = 1.
353-
Val rounded = (bits.shrui(roundoff) & 1) +
354-
(bits.MakeConstant(1).shl(roundoff - 1) - 1);
355-
Val bias{b.create<SelectOp>(roundoff == 0, roundoff, rounded), &b};
367+
Val bias = !use_implicit_bit
368+
? (bits.shrui(roundoff) & 1) +
369+
(bits.MakeConstant(1).shl(roundoff - 1) - 1)
370+
: bits.MakeConstant(1).shl(roundoff - 1);
356371
return bits + bias;
357372
};
358373

@@ -362,9 +377,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
362377
// Round the mantissa if it is shrinking.
363378
Val rounded_from_bits = convert_int(wide_int_ty, from_bits);
364379
if (digit_shift < 0) {
365-
rounded_from_bits = round_bits_to_nearest_even(
366-
from_bits, from_bits.MakeConstant(-digit_shift)) &
367-
~((1ll << (-digit_shift)) - 1);
380+
rounded_from_bits =
381+
round_bits_to_nearest_even(
382+
rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift),
383+
/*use_implicit_bit=*/to_mantissa == 0) &
384+
~((1ll << (-digit_shift)) - 1);
368385
}
369386

370387
// Re-bias the exponent.
@@ -431,10 +448,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
431448
Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0);
432449
bits.value =
433450
b.create<SelectOp>(biased_exp_sle_zero, subnormal_bits, normal_bits);
434-
if (digit_shift > 0) {
451+
if (digit_shift >= 0) {
435452
bits = bits.shl(digit_shift);
436453
} else {
437-
bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift));
454+
bits = round_bits_to_nearest_even(
455+
bits, bits.MakeConstant(-digit_shift),
456+
/*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0);
438457
bits = bits.shrui(-digit_shift);
439458
}
440459
bits = convert_int(to_int_ty, bits);
@@ -443,11 +462,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
443462
} else if (to_min_exp > from_min_exp) {
444463
// `To` supports fewer exponents near zero which means that some values in
445464
// `From` may become subnormal.
446-
Val unbiased_exp = biased_from_exp - from_bias;
447-
Val biased_to_exp = unbiased_exp + to_bias;
465+
Val biased_to_exp = biased_from_exp + (to_bias - from_bias);
448466
// Subnormals and zero.
449467
// Round and shift mantissa down.
450-
Val from_has_leading_one = biased_from_exp != 0;
468+
Val from_has_leading_one =
469+
!from_ty.isFloat8E8M0FNU() ? biased_from_exp != 0 : cst(i32_ty, 1);
451470
Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one);
452471
from_has_leading_one = convert_int(from_int_ty, from_has_leading_one);
453472
Val exponent_shift_i32 =
@@ -482,7 +501,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
482501
result);
483502
}
484503

485-
if (IsFNUZ(to_ty)) {
504+
Value result_is_inf = IsInf(value, b);
505+
Value input_is_nan = IsNaN(value, b);
506+
507+
if (to_ty.isFloat8E8M0FNU()) {
508+
// Converting a negative number to E8M0 results in NaN.
509+
input_is_nan = from_sign_bit | input_is_nan;
510+
} else if (IsFNUZ(to_ty)) {
486511
// Clear the sign bit if the result is zero (the output has no negative
487512
// zero). Handle the edge case when the input is zero and the result is not.
488513
Val result_is_non_zero =
@@ -494,14 +519,17 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
494519
from_sign_bit = from_sign_bit ^ input_is_nan;
495520
}
496521

522+
if (!from_ty.isFloat8E8M0FNU()) {
523+
result = b.create<SelectOp>(from_bits == 0, to_zero, result);
524+
}
497525
result = b.create<SelectOp>(result_is_inf, to_inf, result);
498-
result = b.create<SelectOp>(from_bits == 0, to_zero, result);
499526
result = b.create<SelectOp>(input_is_nan, to_nan, result);
500527

501-
Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));
502-
503528
// Insert sign bit.
504-
result = b.create<SelectOp>(from_sign_bit, neg_result, result);
529+
if (!from_ty.isFloat8E8M0FNU()) {
530+
Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));
531+
result = b.create<SelectOp>(from_sign_bit, neg_result, result);
532+
}
505533
result = b.create<ma::BitcastOp>(to_ty, result);
506534
return result;
507535
}
@@ -598,6 +626,14 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
598626
return rewriter.notifyMatchFailure(op,
599627
"not an f8 (or less) or bf16 absf");
600628
}
629+
630+
// If the type is unsigned (E8M0), the operation is a no-op.
631+
if (!llvm::APFloat::semanticsHasSignedRepr(
632+
src.getType().getFloatSemantics())) {
633+
rewriter.replaceAllOpUsesWith(op, op.getOperand());
634+
return mlir::success();
635+
}
636+
601637
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
602638
mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
603639
Val value{b.create<ma::BitcastOp>(i_ty, src), &b};

xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,16 @@ module {
152152
// CHECK-LABEL: @f4_abs
153153
// CHECK-NOT: math.absf
154154
// CHECK: arith.constant 7 : i4
155+
156+
// -----
157+
158+
module {
159+
func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU {
160+
%ret = math.absf %arg0 : f8E8M0FNU
161+
return %ret : f8E8M0FNU
162+
}
163+
}
164+
165+
// CHECK-LABEL: @e8m0_abs
166+
// CHECK-NOT: math.absf
167+
// CHECK: return %arg0

xla/comparison_util.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,13 @@ class Comparison {
193193
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
194194
// Reference:
195195
// https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
196-
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
197-
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
196+
if constexpr (std::numeric_limits<T>::is_signed) {
197+
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
198+
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
199+
} else {
200+
using R = UnsignedIntegerTypeForSizeType<sizeof(T)>;
201+
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
202+
}
198203
}
199204
}
200205
// Applies the comparison from this Comparison's direction and ordering.

xla/ffi/api/api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os,
147147
return os << "F8E5M2FNUZ";
148148
case XLA_FFI_DataType_F8E4M3FNUZ:
149149
return os << "F8E4M3FNUZ";
150+
case XLA_FFI_DataType_F8E8M0FNU:
151+
return os << "F8E8M0FNU";
150152
}
151153
}
152154

xla/ffi/api/c_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ typedef enum {
202202
XLA_FFI_DataType_F8E5M2FNUZ = 24,
203203
XLA_FFI_DataType_F8E4M3FNUZ = 25,
204204
XLA_FFI_DataType_F4E2M1FN = 30,
205+
XLA_FFI_DataType_F8E8M0FNU = 31,
205206
} XLA_FFI_DataType;
206207
// LINT.ThenChange(ffi_test.cc)
207208

xla/ffi/api/ffi.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ enum class DataType : uint8_t {
8080
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
8181
F8E3M4 = XLA_FFI_DataType_F8E3M4,
8282
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
83+
F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU,
8384
};
8485

8586
// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
@@ -108,6 +109,7 @@ inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
108109
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
109110
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
110111
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
112+
inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU;
111113

112114
inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
113115
return os << static_cast<XLA_FFI_DataType>(dtype);
@@ -130,6 +132,7 @@ constexpr size_t ByteWidth(DataType dtype) {
130132
case DataType::F8E4M3FNUZ:
131133
case DataType::F8E3M4:
132134
case DataType::F4E2M1FN:
135+
case DataType::F8E8M0FNU:
133136
return 1;
134137
case DataType::S16:
135138
case DataType::U16:

0 commit comments

Comments
 (0)