Skip to content

Commit 1f0e19f

Browse files
committed
Add F4E2M1FN type: FFI
1 parent daaa3af commit 1f0e19f

File tree

5 files changed

+10
-0
lines changed

5 files changed

+10
-0
lines changed

xla/ffi/api/api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os,
131131
return os << "C128";
132132
case XLA_FFI_DataType_TOKEN:
133133
return os << "TOKEN";
134+
case XLA_FFI_DataType_F4E2M1FN:
135+
return os << "F4E2M1FN";
134136
case XLA_FFI_DataType_F8E5M2:
135137
return os << "F8E5M2";
136138
case XLA_FFI_DataType_F8E3M4:

xla/ffi/api/c_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ typedef enum {
201201
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
202202
XLA_FFI_DataType_F8E5M2FNUZ = 24,
203203
XLA_FFI_DataType_F8E4M3FNUZ = 25,
204+
XLA_FFI_DataType_F4E2M1FN = 30,
204205
} XLA_FFI_DataType;
205206
// LINT.ThenChange(ffi_test.cc)
206207

xla/ffi/api/ffi.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ enum class DataType : uint8_t {
7979
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
8080
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
8181
F8E3M4 = XLA_FFI_DataType_F8E3M4,
82+
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
8283
};
8384

8485
// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
@@ -106,6 +107,7 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
106107
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
107108
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
108109
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
110+
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
109111

110112
inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
111113
return os << static_cast<XLA_FFI_DataType>(dtype);
@@ -127,6 +129,7 @@ constexpr size_t ByteWidth(DataType dtype) {
127129
case DataType::F8E5M2FNUZ:
128130
case DataType::F8E4M3FNUZ:
129131
case DataType::F8E3M4:
132+
case DataType::F4E2M1FN:
130133
return 1;
131134
case DataType::S16:
132135
case DataType::U16:

xla/ffi/api/ffi_test.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) {
129129

130130
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));
131131

132+
EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN));
132133
EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
133134
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
134135
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
@@ -179,6 +180,8 @@ TEST(FfiTest, DataTypeByteWidth) {
179180
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128),
180181
ByteWidth(DataType::C128));
181182

183+
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN),
184+
ByteWidth(DataType::F4E2M1FN));
182185
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
183186
ByteWidth(DataType::F8E5M2));
184187
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),

xla/ffi/call_frame.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
264264
case PrimitiveType::C64:
265265
case PrimitiveType::C128:
266266
case PrimitiveType::TOKEN:
267+
case PrimitiveType::F4E2M1FN:
267268
case PrimitiveType::F8E5M2:
268269
case PrimitiveType::F8E4M3:
269270
case PrimitiveType::F8E4M3FN:

0 commit comments

Comments
 (0)