File tree 5 files changed +10
-0
lines changed 5 files changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os,
131
131
return os << " C128" ;
132
132
case XLA_FFI_DataType_TOKEN:
133
133
return os << " TOKEN" ;
134
+ case XLA_FFI_DataType_F4E2M1FN:
135
+ return os << " F4E2M1FN" ;
134
136
case XLA_FFI_DataType_F8E5M2:
135
137
return os << " F8E5M2" ;
136
138
case XLA_FFI_DataType_F8E3M4:
Original file line number Diff line number Diff line change @@ -201,6 +201,7 @@ typedef enum {
201
201
XLA_FFI_DataType_F8E4M3B11FNUZ = 23 ,
202
202
XLA_FFI_DataType_F8E5M2FNUZ = 24 ,
203
203
XLA_FFI_DataType_F8E4M3FNUZ = 25 ,
204
+ XLA_FFI_DataType_F4E2M1FN = 30 ,
204
205
} XLA_FFI_DataType;
205
206
// LINT.ThenChange(ffi_test.cc)
206
207
Original file line number Diff line number Diff line change @@ -79,6 +79,7 @@ enum class DataType : uint8_t {
79
79
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
80
80
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
81
81
F8E3M4 = XLA_FFI_DataType_F8E3M4,
82
+ F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
82
83
};
83
84
84
85
// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
@@ -106,6 +107,7 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
106
107
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
107
108
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
108
109
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
110
+ inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
109
111
110
112
inline std::ostream& operator <<(std::ostream& os, const DataType dtype) {
111
113
return os << static_cast <XLA_FFI_DataType>(dtype);
@@ -127,6 +129,7 @@ constexpr size_t ByteWidth(DataType dtype) {
127
129
case DataType::F8E5M2FNUZ:
128
130
case DataType::F8E4M3FNUZ:
129
131
case DataType::F8E3M4:
132
+ case DataType::F4E2M1FN:
130
133
return 1 ;
131
134
case DataType::S16:
132
135
case DataType::U16:
Original file line number Diff line number Diff line change @@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) {
129
129
130
130
EXPECT_EQ (encoded (PrimitiveType::TOKEN), encoded (DataType::TOKEN));
131
131
132
+ EXPECT_EQ (encoded (PrimitiveType::F4E2M1FN), encoded (DataType::F4E2M1FN));
132
133
EXPECT_EQ (encoded (PrimitiveType::F8E5M2), encoded (DataType::F8E5M2));
133
134
EXPECT_EQ (encoded (PrimitiveType::F8E4M3), encoded (DataType::F8E4M3));
134
135
EXPECT_EQ (encoded (PrimitiveType::F8E4M3FN), encoded (DataType::F8E4M3FN));
@@ -179,6 +180,8 @@ TEST(FfiTest, DataTypeByteWidth) {
179
180
EXPECT_EQ (primitive_util::ByteWidth (PrimitiveType::C128),
180
181
ByteWidth (DataType::C128));
181
182
183
+ EXPECT_EQ (primitive_util::ByteWidth (PrimitiveType::F4E2M1FN),
184
+ ByteWidth (DataType::F4E2M1FN));
182
185
EXPECT_EQ (primitive_util::ByteWidth (PrimitiveType::F8E5M2),
183
186
ByteWidth (DataType::F8E5M2));
184
187
EXPECT_EQ (primitive_util::ByteWidth (PrimitiveType::F8E4M3),
Original file line number Diff line number Diff line change @@ -264,6 +264,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
264
264
case PrimitiveType::C64:
265
265
case PrimitiveType::C128:
266
266
case PrimitiveType::TOKEN:
267
+ case PrimitiveType::F4E2M1FN:
267
268
case PrimitiveType::F8E5M2:
268
269
case PrimitiveType::F8E4M3:
269
270
case PrimitiveType::F8E4M3FN:
You can’t perform that action at this time.
0 commit comments