Skip to content

Commit ad4d33e

Browse files
apivovarovGoogle-ML-Automation
authored andcommitted
PR #16585: Add support for float8_e4m3 and float8_e3m4 types
Imported from GitHub PR #16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723 PiperOrigin-RevId: 680651037
1 parent 869808b commit ad4d33e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+1509
-155
lines changed

third_party/tsl/tsl/platform/ml_dtypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ limitations under the License.
2020
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
2121

2222
namespace tsl {
23+
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
24+
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
2325
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
2426
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
2527
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;

xla/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ xla_cc_test(
316316
":util",
317317
"@com_google_absl//absl/base",
318318
"@com_google_absl//absl/numeric:bits",
319+
"@com_google_googletest//:gtest_main",
319320
"@tsl//tsl/platform:ml_dtypes",
320321
"@tsl//tsl/platform:test_main",
321322
],
@@ -373,6 +374,7 @@ xla_cc_test(
373374
":test",
374375
":types",
375376
":util",
377+
"@ml_dtypes//:float8",
376378
"@tsl//tsl/platform:logging",
377379
"@tsl//tsl/platform:ml_dtypes",
378380
"@tsl//tsl/platform:test_main",

xla/array2d_test.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) {
162162
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
163163
}
164164

165+
TEST(Array2dTest, LinspaceF8E4M3) {
166+
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3>(1.0, 3.5, 3, 2);
167+
168+
EXPECT_EQ(arr->n1(), 3);
169+
EXPECT_EQ(arr->n2(), 2);
170+
171+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
172+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
173+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
174+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.5);
175+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
176+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
177+
}
178+
165179
TEST(Array2dTest, LinspaceF8E4M3Fn) {
166180
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3fn>(1.0, 3.5, 3, 2);
167181

@@ -190,6 +204,20 @@ TEST(Array2dTest, LinspaceF8E4M3FnNoNan) {
190204
}
191205
}
192206

207+
TEST(Array2dTest, LinspaceF8E3M4) {
208+
auto arr = MakeLinspaceArray2D<tsl::float8_e3m4>(1.0, 3.5, 3, 2);
209+
210+
EXPECT_EQ(arr->n1(), 3);
211+
EXPECT_EQ(arr->n2(), 2);
212+
213+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
214+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
215+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
216+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.5);
217+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
218+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
219+
}
220+
193221
TEST(Array2dTest, Stringification) {
194222
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
195223
const std::string expected = R"([[1, 1.5],

xla/ffi/api/api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ inline std::ostream& operator<<(std::ostream& os,
133133
return os << "TOKEN";
134134
case XLA_FFI_DataType_F8E5M2:
135135
return os << "F8E5M2";
136+
case XLA_FFI_DataType_F8E3M4:
137+
return os << "F8E3M4";
138+
case XLA_FFI_DataType_F8E4M3:
139+
return os << "F8E4M3";
136140
case XLA_FFI_DataType_F8E4M3FN:
137141
return os << "F8E4M3FN";
138142
case XLA_FFI_DataType_F8E4M3B11FNUZ:

xla/ffi/api/c_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ typedef enum {
195195
XLA_FFI_DataType_C128 = 18,
196196
XLA_FFI_DataType_TOKEN = 17,
197197
XLA_FFI_DataType_F8E5M2 = 19,
198+
XLA_FFI_DataType_F8E3M4 = 29,
199+
XLA_FFI_DataType_F8E4M3 = 28,
198200
XLA_FFI_DataType_F8E4M3FN = 20,
199201
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
200202
XLA_FFI_DataType_F8E5M2FNUZ = 24,

xla/ffi/api/ffi.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ enum class DataType : uint8_t {
7373
C128 = XLA_FFI_DataType_C128,
7474
TOKEN = XLA_FFI_DataType_TOKEN,
7575
F8E5M2 = XLA_FFI_DataType_F8E5M2,
76+
F8E4M3 = XLA_FFI_DataType_F8E4M3,
7677
F8E4M3FN = XLA_FFI_DataType_F8E4M3FN,
7778
F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ,
7879
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
7980
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
81+
F8E3M4 = XLA_FFI_DataType_F8E3M4,
8082
};
8183

8284
// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
@@ -98,10 +100,12 @@ inline constexpr DataType C64 = DataType::C64;
98100
inline constexpr DataType C128 = DataType::C128;
99101
inline constexpr DataType TOKEN = DataType::TOKEN;
100102
inline constexpr DataType F8E5M2 = DataType::F8E5M2;
103+
inline constexpr DataType F8E4M3 = DataType::F8E4M3;
101104
inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
102105
inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
103106
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
104107
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
108+
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
105109

106110
inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
107111
return os << static_cast<XLA_FFI_DataType>(dtype);
@@ -117,10 +121,12 @@ constexpr size_t ByteWidth(DataType dtype) {
117121
case DataType::S8:
118122
case DataType::U8:
119123
case DataType::F8E5M2:
124+
case DataType::F8E4M3:
120125
case DataType::F8E4M3FN:
121126
case DataType::F8E4M3B11FNUZ:
122127
case DataType::F8E5M2FNUZ:
123128
case DataType::F8E4M3FNUZ:
129+
case DataType::F8E3M4:
124130
return 1;
125131
case DataType::S16:
126132
case DataType::U16:

xla/ffi/api/ffi_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,13 @@ TEST(FfiTest, DataTypeEnumValue) {
130130
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));
131131

132132
EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
133+
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
133134
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
134135
EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ),
135136
encoded(DataType::F8E4M3B11FNUZ));
136137
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
137138
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
139+
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
138140
}
139141

140142
TEST(FfiTest, DataTypeByteWidth) {
@@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {
179181

180182
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
181183
ByteWidth(DataType::F8E5M2));
184+
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
185+
ByteWidth(DataType::F8E4M3));
182186
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN),
183187
ByteWidth(DataType::F8E4M3FN));
184188
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ),
@@ -187,6 +191,8 @@ TEST(FfiTest, DataTypeByteWidth) {
187191
ByteWidth(DataType::F8E5M2FNUZ));
188192
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FNUZ),
189193
ByteWidth(DataType::F8E4M3FNUZ));
194+
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
195+
ByteWidth(DataType::F8E3M4));
190196
}
191197

192198
TEST(FfiTest, ErrorEnumValue) {

xla/ffi/call_frame.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,12 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
265265
case PrimitiveType::C128:
266266
case PrimitiveType::TOKEN:
267267
case PrimitiveType::F8E5M2:
268+
case PrimitiveType::F8E4M3:
268269
case PrimitiveType::F8E4M3FN:
269270
case PrimitiveType::F8E4M3B11FNUZ:
270271
case PrimitiveType::F8E5M2FNUZ:
271272
case PrimitiveType::F8E4M3FNUZ:
273+
case PrimitiveType::F8E3M4:
272274
return static_cast<XLA_FFI_DataType>(primitive_type);
273275
default:
274276
DCHECK(false) << "Unsupported primitive type "

xla/fp_util_test.cc

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <cstdint>
2020
#include <limits>
2121

22+
#include <gtest/gtest.h>
2223
#include "absl/base/casts.h"
2324
#include "absl/numeric/bits.h"
2425
#include "xla/bit_cast.h"
@@ -111,21 +112,74 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
111112
0x1.fffffffffffffp-127,
112113
0x1.aaaaaaaaaaaaap-127));
113114

114-
// Test F8E4M3 floating-point types (F8E4M3FN)
115+
// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN)
115116
template <typename T>
116117
class FP8E4M3DistanceTest : public ::testing::Test {};
117118

118-
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3fn>;
119+
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
119120
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);
120121

122+
TEST(FPDistanceTest, F8E3M4Distance) {
123+
// a & b are equal
124+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
125+
tsl::float8_e3m4(8.0)),
126+
0);
127+
128+
// a & b have the same exponents
129+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
130+
tsl::float8_e3m4(15.5)),
131+
15);
132+
133+
// a & b have different exponents
134+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
135+
tsl::float8_e3m4(6)),
136+
8);
137+
138+
// 1 from 0 in the positive direction
139+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
140+
std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
141+
tsl::float8_e3m4(0)),
142+
1);
143+
144+
// 1 from 0 in the negative direction
145+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
146+
-std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
147+
tsl::float8_e3m4(0)),
148+
1);
149+
150+
// a & b have different signs
151+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
152+
-std::numeric_limits<tsl::float8_e3m4>::denorm_min(),
153+
std::numeric_limits<tsl::float8_e3m4>::denorm_min()),
154+
2);
155+
156+
// 1 non denorm from 0 in the positive direction
157+
EXPECT_EQ(
158+
CalculateDistanceInFloats<tsl::float8_e3m4>(
159+
std::numeric_limits<tsl::float8_e3m4>::min(), tsl::float8_e3m4(0)),
160+
16);
161+
162+
// 1 non denorm from 0 in the negative direction
163+
EXPECT_EQ(
164+
CalculateDistanceInFloats<tsl::float8_e3m4>(
165+
-std::numeric_limits<tsl::float8_e3m4>::min(), tsl::float8_e3m4(0)),
166+
16);
167+
168+
// a & b have different signs
169+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(
170+
-std::numeric_limits<tsl::float8_e3m4>::min(),
171+
std::numeric_limits<tsl::float8_e3m4>::min()),
172+
32);
173+
}
174+
121175
TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) {
122176
// a & b are equal, distance should be 0
123177
EXPECT_EQ(
124178
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(8.0)), 0);
125179

126180
// a & b have the same exponents
127-
EXPECT_EQ(CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(13)),
128-
5);
181+
EXPECT_EQ(
182+
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(15.0)), 7);
129183

130184
// a & b have different exponents
131185
EXPECT_EQ(

xla/hlo/builder/lib/math.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ XlaOp IsNegZero(XlaOp operand) {
175175
case F32:
176176
return Eq(BitcastConvertType(operand, U32),
177177
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
178+
case F8E3M4:
179+
case F8E4M3:
178180
case F8E5M2:
179181
case F8E4M3FN:
180182
case F8E4M3B11FNUZ:
@@ -973,8 +975,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
973975
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
974976
PrimitiveType a_x_type = a_shape.element_type();
975977
bool needs_upcast = false;
976-
for (PrimitiveType type :
977-
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
978+
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
979+
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
978980
if (a_shape.element_type() == type) {
979981
needs_upcast = true;
980982
break;
@@ -1026,8 +1028,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
10261028
}
10271029
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
10281030
bool needs_upcast = false;
1029-
for (PrimitiveType type :
1030-
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
1031+
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
1032+
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
10311033
if (a_shape.element_type() == type) {
10321034
needs_upcast = true;
10331035
break;

xla/hlo/evaluator/hlo_evaluator_typed_visitor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,10 +1743,12 @@ extern template class HloEvaluatorTypedVisitor<complex64>;
17431743
extern template class HloEvaluatorTypedVisitor<complex128>;
17441744
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
17451745
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
1746+
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
17461747
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
17471748
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
17481749
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
17491750
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
1751+
extern template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;
17501752

17511753
} // namespace xla
17521754

xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ limitations under the License.
1919

2020
namespace xla {
2121
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
22+
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
2223
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
2324
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
2425
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
2526
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
27+
template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;
2628
} // namespace xla

xla/hlo/translate/hlo_to_mhlo/tests/import.hlo

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,17 @@ add {
410410
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3B11FNUZ>
411411
%constant.9 = f8e4m3b11fnuz[4] constant({1, 2, 3, 4})
412412

413-
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ>
413+
// CHECK: %[[VAL_10:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ>
414414
%constant.10 = f8e4m3fnuz[4] constant({1, 2, 3, 4})
415415

416-
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>
416+
// CHECK: %[[VAL_11:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>
417417
%constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4})
418+
419+
// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3>
420+
%constant.12 = f8e4m3[4] constant({1, 2, 3, 4})
421+
422+
// CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>
423+
%constant.13 = f8e3m4[4] constant({1, 2, 3, 4})
418424
}
419425

420426
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
@@ -524,7 +530,19 @@ add {
524530
%convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10)
525531

526532
// CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32>
527-
ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11)
533+
%convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11)
534+
535+
// CHECK-NEXT: %10 = mhlo.convert %9 : (tensor<4xf32>) -> tensor<4xf8E4M3>
536+
%convert.13 = f8e4m3[4] convert(f32[4] %convert.12)
537+
538+
// CHECK-NEXT: %11 = mhlo.convert %10 : (tensor<4xf8E4M3>) -> tensor<4xf32>
539+
%convert.14 = f32[4] convert(f8e4m3[4] %convert.13)
540+
541+
// CHECK-NEXT: %12 = mhlo.convert %11 : (tensor<4xf32>) -> tensor<4xf8E3M4>
542+
%convert.15 = f8e3m4[4] convert(f32[4] %convert.14)
543+
544+
// CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32>
545+
ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15)
528546
}
529547

530548
// CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8>

xla/hlo/translate/mhlo_to_hlo/tests/export.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,12 @@ func.func @main() {
600600
// CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4})
601601
%cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>
602602

603+
// CHECK: f8e4m3[4] constant({1, 2, 3, 4})
604+
%cst_16 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3>
605+
606+
// CHECK: f8e3m4[4] constant({1, 2, 3, 4})
607+
%cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>
608+
603609
func.return
604610
}
605611

@@ -729,7 +735,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
729735
%5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32>
730736
%6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ>
731737
%7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32>
732-
func.return %7 : tensor<2xf32>
738+
%8 = "mhlo.convert"(%7) : (tensor<2xf32>) -> tensor<2xf8E4M3>
739+
%9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32>
740+
%10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4>
741+
%11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32>
742+
func.return %11 : tensor<2xf32>
733743
}
734744

735745
// CHECK: ENTRY
@@ -741,7 +751,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
741751
// CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]])
742752
// CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]])
743753
// CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]])
744-
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]])
754+
// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]])
755+
// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]])
756+
// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]])
757+
// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]])
758+
// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])
745759

746760
// -----
747761

0 commit comments

Comments
 (0)