Skip to content

Commit 70ca820

Browse files
committed
Add F4E2M1FN type: literal support
1 parent 87d0056 commit 70ca820

File tree

7 files changed

+124
-61
lines changed

7 files changed

+124
-61
lines changed

xla/hlo/translate/hlo_to_mhlo/tests/import.hlo

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,9 @@ add {
421421

422422
// CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>
423423
%constant.13 = f8e3m4[4] constant({1, 2, 3, 4})
424+
425+
// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN>
426+
%constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4})
424427
}
425428

426429
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
@@ -542,7 +545,13 @@ add {
542545
%convert.15 = f8e3m4[4] convert(f32[4] %convert.14)
543546

544547
// CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32>
545-
ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15)
548+
%convert.16 = f32[4] convert(f8e3m4[4] %convert.15)
549+
550+
// CHECK-NEXT: %14 = mhlo.convert %13 : (tensor<4xf32>) -> tensor<4xf4E2M1FN>
551+
%convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16)
552+
553+
// CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32>
554+
ROOT %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17)
546555
}
547556

548557
// CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8>

xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ xla::Array<T> ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) {
4141
xla::Array<T> array(shape.dimensions());
4242
if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) {
4343
array.SetValues(dense_attr.getValues<T>());
44+
} else if constexpr (xla::primitive_util::IsMXType(type)) {
45+
// Bitcast MX floating point types from APFloat.
46+
auto values = dense_attr.getValues<llvm::APFloat>();
47+
for (int i = 0; i < values.size(); i++) {
48+
array.data()[i] = T::FromRep(values[i].bitcastToAPInt().getZExtValue());
49+
}
4450
} else {
4551
// The only way to get subbyte integers from getValues() is to get them as
4652
// APInts.

xla/hlo/translate/mhlo_to_hlo/tests/export.mlir

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,9 @@ func.func @main() {
606606
// CHECK: f8e3m4[4] constant({1, 2, 3, 4})
607607
%cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>
608608

609+
// CHECK: f4e2m1fn[4] constant({1, 2, 3, 4})
610+
%cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN>
611+
609612
func.return
610613
}
611614

@@ -739,7 +742,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
739742
%9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32>
740743
%10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4>
741744
%11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32>
742-
func.return %11 : tensor<2xf32>
745+
%12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN>
746+
%13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32>
747+
func.return %13 : tensor<2xf32>
743748
}
744749

745750
// CHECK: ENTRY
@@ -755,7 +760,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
755760
// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]])
756761
// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]])
757762
// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]])
758-
// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])
763+
// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])
764+
// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]])
765+
// CHECK: ROOT %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]])
759766

760767
// -----
761768

xla/literal.cc

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ bool LiteralProtoHasValues(const LiteralProto& proto) {
9191
!proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() ||
9292
!proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() ||
9393
!proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() ||
94-
!proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() ||
95-
!proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() ||
96-
!proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() ||
97-
!proto.f8e3m4s().empty() || !proto.f16s().empty() ||
98-
!proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() ||
99-
proto.c64s_size() || proto.c128s_size() || proto.preds_size() ||
100-
proto.tuple_literals_size();
94+
!proto.f4e2m1fns().empty() || !proto.f8e3m4s().empty() ||
95+
!proto.f8e4m3b11fnuzs().empty() || !proto.f8e4m3fns().empty() ||
96+
!proto.f8e4m3fnuzs().empty() || !proto.f8e4m3s().empty() ||
97+
!proto.f8e5m2fnuzs().empty() || !proto.f8e5m2s().empty() ||
98+
!proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() ||
99+
proto.f64s_size() || proto.c64s_size() || proto.c128s_size() ||
100+
proto.preds_size() || proto.tuple_literals_size();
101101
}
102102

103103
// Lazy getter for the interned scalar shape in static storage. We reuse this
@@ -1874,7 +1874,6 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
18741874
<< __func__ << " is only supported for dense arrays: " << subshape();
18751875
CHECK_EQ(size_bytes_dense(), other.size_bytes_dense());
18761876
if (primitive_util::IsSubByteNonPredType(subshape().element_type())) {
1877-
CHECK(!primitive_util::IsFloatingPointType(subshape().element_type()));
18781877
auto one_array = buffer();
18791878
auto two_array = other.buffer();
18801879
const int bits_per_element =
@@ -2259,6 +2258,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
22592258
case S64:
22602259
CopyToRepeatedField(proto->mutable_s64s(), data<int64_t>());
22612260
break;
2261+
case F4E2M1FN:
2262+
*proto->mutable_f4e2m1fns() = std::string(
2263+
reinterpret_cast<const char*>(data<tsl::float4_e2m1fn>().data()),
2264+
size_bytes_dense());
2265+
break;
22622266
case F8E5M2:
22632267
*proto->mutable_f8e5m2s() = std::string(
22642268
reinterpret_cast<const char*>(data<tsl::float8_e5m2>().data()),
@@ -2445,6 +2449,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
24452449
case U64:
24462450
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64_t>(), proto.u64s()));
24472451
break;
2452+
case F4E2M1FN: {
2453+
const std::string& s(proto.f4e2m1fns());
2454+
TF_RET_CHECK(data<tsl::float4_e2m1fn>().size() *
2455+
sizeof(tsl::float4_e2m1fn) ==
2456+
s.size());
2457+
memcpy(untyped_data(), s.data(), s.size());
2458+
break;
2459+
}
24482460
case F8E5M2: {
24492461
const std::string& s(proto.f8e5m2s());
24502462
TF_RET_CHECK(data<tsl::float8_e5m2>().size() * sizeof(tsl::float8_e5m2) ==

xla/literal.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,6 @@ class LiteralBase {
589589
primitive_util::NativeToPrimitiveType<NativeT>();
590590
constexpr int bits_per_element = primitive_util::BitWidth(primitive_type);
591591
if constexpr (bits_per_element < 8) {
592-
static_assert(!primitive_util::IsFloatingPointType(primitive_type));
593592
static_assert(!primitive_util::IsComplexType(primitive_type));
594593
static_assert(8 % bits_per_element == 0);
595594
constexpr int elements_per_byte = 8 / bits_per_element;
@@ -598,9 +597,9 @@ class LiteralBase {
598597
for (int64_t i = 0; i < bytes; ++i) {
599598
uint8_t byte = 0;
600599
for (int b = 0; b < elements_per_byte; ++b) {
601-
uint8_t src =
602-
static_cast<uint8_t>(elements[i * elements_per_byte + b]) &
603-
LsbMask<uint8_t>(bits_per_element);
600+
uint8_t src = Eigen::numext::bit_cast<uint8_t>(
601+
elements[i * elements_per_byte + b]) &
602+
LsbMask<uint8_t>(bits_per_element);
604603
byte |= src << (b * bits_per_element);
605604
}
606605
WriteElement(byte);
@@ -609,9 +608,9 @@ class LiteralBase {
609608
if (rest != 0) {
610609
uint8_t byte = 0;
611610
for (int64_t b = 0; b < rest; ++b) {
612-
uint8_t src =
613-
static_cast<uint8_t>(elements[bytes * elements_per_byte + b]) &
614-
LsbMask<uint8_t>(bits_per_element);
611+
uint8_t src = Eigen::numext::bit_cast<uint8_t>(
612+
elements[bytes * elements_per_byte + b]) &
613+
LsbMask<uint8_t>(bits_per_element);
615614
byte |= src << (b * bits_per_element);
616615
}
617616
WriteElement(byte);
@@ -701,10 +700,16 @@ class LiteralBase {
701700
primitive_util::NativeToPrimitiveType<NativeT>();
702701
constexpr int bits_per_element = primitive_util::BitWidth(primitive_type);
703702
if constexpr (bits_per_element < 8) {
704-
static_assert(!primitive_util::IsFloatingPointType(primitive_type));
705703
static_assert(!primitive_util::IsComplexType(primitive_type));
706704
static_assert(8 % bits_per_element == 0);
705+
707706
constexpr int elements_per_byte = 8 / bits_per_element;
707+
constexpr auto cast = [](uint8_t x) -> NativeT {
708+
if constexpr (primitive_util::IsFloatingPointType(primitive_type)) {
709+
return Eigen::numext::bit_cast<NativeT>(x);
710+
}
711+
return static_cast<NativeT>(x);
712+
};
708713

709714
int64_t bytes = elements.size() / elements_per_byte;
710715
for (int64_t i = 0; i < bytes; ++i) {
@@ -714,7 +719,7 @@ class LiteralBase {
714719
}
715720
for (int b = 0; b < elements_per_byte; ++b) {
716721
elements[i * elements_per_byte + b] =
717-
static_cast<NativeT>(byte & LsbMask<uint8_t>(bits_per_element));
722+
cast(byte & LsbMask<uint8_t>(bits_per_element));
718723
byte >>= bits_per_element;
719724
}
720725
}
@@ -726,7 +731,7 @@ class LiteralBase {
726731
}
727732
for (int64_t b = 0; b < rest; ++b) {
728733
elements[bytes * elements_per_byte + b] =
729-
static_cast<NativeT>(byte & LsbMask<uint8_t>(bits_per_element));
734+
cast(byte & LsbMask<uint8_t>(bits_per_element));
730735
byte >>= bits_per_element;
731736
}
732737
}

xla/literal_comparison_test.cc

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,30 @@ namespace {
2929
template <typename T>
3030
class LiteralComparisonTest : public ::testing::Test {};
3131

32-
using TestedTypes =
33-
::testing::Types<tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fn,
34-
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2>;
32+
using TestedTypes = ::testing::Types<tsl::float4_e2m1fn, tsl::float8_e3m4,
33+
tsl::float8_e4m3, tsl::float8_e4m3b11fnuz,
34+
tsl::float8_e4m3fn, tsl::float8_e4m3fnuz,
35+
tsl::float8_e5m2, tsl::float8_e5m2fnuz>;
3536
TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes);
3637

3738
TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) {
38-
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
39-
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
39+
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
40+
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
4041
TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0),
4142
/*detailed_message=*/false,
4243
/*miscompare_callback=*/nullptr));
4344
}
4445

4546
TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) {
4647
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
47-
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
48-
float expV = 9.0; // F8E4M3*
49-
if (type == F8E5M2)
50-
expV = 10.0;
48+
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
49+
float expV = 1.125; // F8E4M3*
50+
if (type == F8E5M2 || type == F8E5M2FNUZ)
51+
expV = 1.25;
5152
else if (type == F8E3M4)
52-
expV = 8.5;
53+
expV = 1.0625;
54+
else if (type == F4E2M1FN)
55+
expV = 1.5;
5356
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam{expV});
5457
auto error_spec = ErrorSpec(0.0, 0.0);
5558
EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec,
@@ -64,12 +67,14 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) {
6467

6568
TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) {
6669
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
67-
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
68-
float expV = 12.0; // F8E4M3*
69-
if (type == F8E5M2)
70-
expV = 14.0;
70+
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
71+
float expV = 1.5; // F8E4M3*
72+
if (type == F8E5M2 || type == F8E5M2FNUZ)
73+
expV = 1.75;
7174
else if (type == F8E3M4)
72-
expV = 10.0;
75+
expV = 1.25;
76+
else if (type == F4E2M1FN)
77+
expV = 4.0;
7378
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam{expV});
7479
auto error_spec = ErrorSpec(0.0, 0.0);
7580
error_spec.low_precision_fp_error_spec.type = type;
@@ -86,12 +91,14 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) {
8691

8792
TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) {
8893
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
89-
auto actual = LiteralUtil::CreateR0<float>(8.0);
90-
float expV = 12.1; // F8E4M3*
91-
if (type == F8E5M2)
92-
expV = 13.0;
94+
auto actual = LiteralUtil::CreateR0<float>(1.0);
95+
float expV = 1.51; // F8E4M3*
96+
if (type == F8E5M2 || type == F8E5M2FNUZ)
97+
expV = 1.76;
9398
else if (type == F8E3M4)
94-
expV = 10.125;
99+
expV = 1.26;
100+
else if (type == F4E2M1FN)
101+
expV = 4.1;
95102
auto expected = LiteralUtil::CreateR0<float>(expV);
96103
auto error_spec = ErrorSpec(0.0, 0.0);
97104
error_spec.low_precision_fp_error_spec.type = type;

0 commit comments

Comments
 (0)