Skip to content

Commit a4074ba

Browse files
sergey-kozubGoogle-ML-Automation
authored andcommitted
PR #21380: Add F4E2M1FN and F8E8M0FNU types
Imported from GitHub PR #21380 Previous PR #19096 was rolled back, re-trying. This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented. This will enable using microscaling (MX) formats ([RFC](#18085)), such as MXFP4. ```c F4E2M1FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.0 - Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0 - Min normal number: S.01.0 = ±2^(0) = ±1.0 - Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5 F8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 − 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - openxla/stablehlo#2582 - jax-ml/ml_dtypes#181 - llvm/llvm-project#95392 - llvm/llvm-project#108877 - jax-ml/ml_dtypes#166 - llvm/llvm-project#107127 - llvm/llvm-project#111028 Copybara import of the project: -- d7e00c4 by Sergey Kozub <[email protected]>: Add F4E2M1FN and F8E8M0FNU types Merging this change closes #21380 FUTURE_COPYBARA_INTEGRATE_REVIEW=#21380 from openxla:skozub/e2m1_e8m0 d7e00c4 PiperOrigin-RevId: 715070992
1 parent 7d912e5 commit a4074ba

File tree

79 files changed

+1853
-377
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1853
-377
lines changed

xla/array2d_test.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) {
219219
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
220220
}
221221

222+
TEST(Array2dTest, LinspaceF4E2M1FN) {
223+
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);
224+
225+
EXPECT_EQ(arr->n1(), 3);
226+
EXPECT_EQ(arr->n2(), 2);
227+
228+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
229+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
230+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
231+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
232+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
233+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
234+
}
235+
236+
TEST(Array2dTest, LinspaceF8E8M0FNU) {
237+
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);
238+
239+
EXPECT_EQ(arr->n1(), 3);
240+
EXPECT_EQ(arr->n2(), 2);
241+
242+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
243+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
244+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
245+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
246+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
247+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
248+
}
249+
222250
TEST(Array2dTest, Stringification) {
223251
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
224252
const std::string expected = R"([[1, 1.5],

xla/backends/gpu/codegen/transforms/expand_float_ops.cc

Lines changed: 119 additions & 72 deletions
Large diffs are not rendered by default.

xla/backends/gpu/codegen/transforms/lower_tensors.cc

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
297297
mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
298298
Value linear_index, mlir::ImplicitLocOpBuilder& b) {
299299
Type element_type = tensor.getType().getElementType();
300-
if (element_type == b.getI4Type()) {
300+
if (element_type.isIntOrFloat() &&
301+
element_type.getIntOrFloatBitWidth() == 4) {
301302
element_type = b.getI8Type();
302303
}
303304
auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
@@ -326,7 +327,8 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
326327
auto linear_index = GetLinearIndex(op.getIndices(), b);
327328
Type element_type = op.getTensor().getType().getElementType();
328329
Value is_low_nibble = nullptr;
329-
if (element_type == rewriter.getI4Type()) {
330+
if (element_type.isIntOrFloat() &&
331+
element_type.getIntOrFloatBitWidth() == 4) {
330332
std::tie(linear_index, is_low_nibble) =
331333
GetI4IndexAndNibble(linear_index, b);
332334
}
@@ -341,7 +343,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
341343
auto high_value = b.create<mlir::arith::ShRUIOp>(
342344
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
343345
load = b.create<mlir::arith::TruncIOp>(
344-
op.getType(),
346+
rewriter.getI4Type(),
345347
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
346348
}
347349

@@ -377,6 +379,7 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
377379

378380
auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
379381
op.getSource());
382+
mlir::Type source_element_type = source.getType().getElementType();
380383

381384
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
382385
auto linear_index = GetLinearIndex(op.getIndices(), b);
@@ -385,7 +388,9 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
385388
if (vector_type.getElementType().isInteger(1)) {
386389
vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
387390
}
388-
if (op.getVectorType().getElementType().isInteger(4)) {
391+
mlir::Type gep_element_type = vector_type.getElementType();
392+
if (gep_element_type.isIntOrFloat() &&
393+
gep_element_type.getIntOrFloatBitWidth() == 4) {
389394
linear_index = b.create<arith::ShRUIOp>(
390395
linear_index,
391396
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
397402
auto loaded =
398403
b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
399404

400-
if (source.getType().getElementType().isInteger(1)) {
405+
if (source_element_type.isInteger(1)) {
401406
Value zero = b.create<mlir::arith::ConstantOp>(
402407
mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
403408
loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
404-
} else if (source.getType().getElementType().isInteger(4)) {
409+
} else if (source_element_type.isIntOrFloat() &&
410+
source_element_type.getIntOrFloatBitWidth() == 4) {
405411
// LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
406412
// elements.
407413
loaded = PermutePairsInVector(loaded, b);
@@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
430436
auto scalar_value = op.getScalar();
431437

432438
// For i4 we store 2 values into one byte. This needs special handling here.
433-
if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
439+
if (tensor_dest.getType().getElementType().isIntOrFloat() &&
440+
tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
434441
// We need to use directly op.getDest() as input, otherwise the following
435442
// rewrite might remove the only user of it.
436443
tensor_dest = op.getDest();
@@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
448455
auto tensor_dest_i8 =
449456
b.create<UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
450457
.getResult(0);
458+
if (scalar_value.getType() != rewriter.getI4Type()) {
459+
scalar_value =
460+
b.create<arith::BitcastOp>(rewriter.getI4Type(), scalar_value);
461+
}
451462
scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);
452463

453464
// We need AtomicRMWOp because it can happen that different threads try to
@@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
507518
auto linear_index = GetLinearIndex(op.getIndices(), b);
508519

509520
mlir::Value vector_value = op.getVector();
510-
if (op.getVectorType().getElementType().isInteger(1)) {
521+
mlir::Type vector_element_type = op.getVectorType().getElementType();
522+
if (vector_element_type.isInteger(1)) {
511523
vector_value = b.create<arith::ExtUIOp>(
512524
op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
513525
vector_value);
514526
}
515-
if (op.getVectorType().getElementType().isInteger(4)) {
527+
if (vector_element_type.isIntOrFloat() &&
528+
vector_element_type.getIntOrFloatBitWidth() == 4) {
516529
linear_index = b.create<arith::ShRUIOp>(
517530
linear_index,
518531
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -577,21 +590,19 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
577590
// Needed to support complex element type.
578591
mlir::LLVMTypeConverter converter(b.getContext());
579592
auto llvm_element_type = converter.convertType(element_type);
580-
if (mlir::isa<mlir::IntegerType>(element_type)) {
581-
int bit_width = mlir::cast<mlir::IntegerType>(element_type).getWidth();
582-
if (bit_width == 4) {
583-
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
584-
llvm_element_type = b.getI8Type();
585-
auto unpacked_data =
586-
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
587-
std::vector<char> packed_data(num_elements);
588-
absl::Span<char> packed_data_span =
589-
absl::MakeSpan(packed_data.data(), packed_data.size());
590-
PackIntN(4, unpacked_data, packed_data_span);
591-
value = mlir::DenseElementsAttr::getFromRawBuffer(
592-
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
593-
packed_data);
594-
}
593+
if (element_type.isIntOrFloat() &&
594+
element_type.getIntOrFloatBitWidth() == 4) {
595+
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
596+
llvm_element_type = b.getI8Type();
597+
auto unpacked_data =
598+
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
599+
std::vector<char> packed_data(num_elements);
600+
absl::Span<char> packed_data_span =
601+
absl::MakeSpan(packed_data.data(), packed_data.size());
602+
PackIntN(4, unpacked_data, packed_data_span);
603+
value = mlir::DenseElementsAttr::getFromRawBuffer(
604+
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
605+
packed_data);
595606
}
596607
auto array_ty =
597608
mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements);

xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,53 @@ module {
115115
// CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32
116116
// CHECK: arith.truncf %[[EXT]] : f32 to f16
117117
// CHECK-NOT: arith.truncf
118+
119+
// -----
120+
121+
module {
122+
func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 {
123+
%ret = arith.extf %arg0 : f4E2M1FN to f16
124+
return %ret : f16
125+
}
126+
}
127+
128+
// CHECK-LABEL: @f4_to_f16
129+
// CHECK-NOT: arith.extf
130+
131+
// -----
132+
133+
module {
134+
func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN {
135+
%ret = arith.truncf %arg0 : f16 to f4E2M1FN
136+
return %ret : f4E2M1FN
137+
}
138+
}
139+
140+
// CHECK-LABEL: @f16_to_f4
141+
// CHECK-NOT: arith.truncf
142+
143+
// -----
144+
145+
module {
146+
func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN {
147+
%ret = math.absf %arg0 : f4E2M1FN
148+
return %ret : f4E2M1FN
149+
}
150+
}
151+
152+
// CHECK-LABEL: @f4_abs
153+
// CHECK-NOT: math.absf
154+
// CHECK: arith.constant 7 : i4
155+
156+
// -----
157+
158+
module {
159+
func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU {
160+
%ret = math.absf %arg0 : f8E8M0FNU
161+
return %ret : f8E8M0FNU
162+
}
163+
}
164+
165+
// CHECK-LABEL: @e8m0_abs
166+
// CHECK-NOT: math.absf
167+
// CHECK: return %arg0

xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,4 +763,44 @@ func.func @for_op(%arg0: tensor<500xf32>) -> f32 {
763763

764764
// CHECK-LABEL: @for_op
765765
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
766-
// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) {
766+
// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) {
767+
768+
// -----
769+
770+
func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN {
771+
%cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN>
772+
%extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN>
773+
%extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN>
774+
%0 = arith.addf %extracted, %extracted_0 : f4E2M1FN
775+
return %0 : f4E2M1FN
776+
}
777+
// CHECK: llvm.mlir.global private constant
778+
// CHECK-SAME: dense<[25, 64]>
779+
// CHECK-LABEL: @f4_constant
780+
781+
// -----
782+
783+
func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> {
784+
%c16 = arith.constant 16 : index
785+
%c0 = arith.constant 0.0 : f4E2M1FN
786+
%out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN>
787+
func.return %out : vector<2xf4E2M1FN>
788+
}
789+
// CHECK-LABEL: @transfer_read_f4
790+
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8]
791+
// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4>
792+
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN>
793+
// CHECK: return %[[OUT]] : vector<2xf4E2M1FN>
794+
795+
// -----
796+
797+
func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1},
798+
%arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> {
799+
%c10 = arith.constant 10 : index
800+
%out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN>
801+
func.return %out : tensor<43xf4E2M1FN>
802+
}
803+
// CHECK-LABEL: @transfer_write_f4
804+
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8
805+
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4>
806+
// CHECK: llvm.store %[[OUT]], %[[PTR]] : vector<2xi4>, !llvm.ptr

xla/comparison_util.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,13 @@ class Comparison {
193193
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
194194
// Reference:
195195
// https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
196-
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
197-
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
196+
if constexpr (std::numeric_limits<T>::is_signed) {
197+
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
198+
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
199+
} else {
200+
using R = UnsignedIntegerTypeForSizeType<sizeof(T)>;
201+
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
202+
}
198203
}
199204
}
200205
// Applies the comparison from this Comparison's direction and ordering.

xla/ffi/api/api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os,
131131
return os << "C128";
132132
case XLA_FFI_DataType_TOKEN:
133133
return os << "TOKEN";
134+
case XLA_FFI_DataType_F4E2M1FN:
135+
return os << "F4E2M1FN";
134136
case XLA_FFI_DataType_F8E5M2:
135137
return os << "F8E5M2";
136138
case XLA_FFI_DataType_F8E3M4:
@@ -145,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os,
145147
return os << "F8E5M2FNUZ";
146148
case XLA_FFI_DataType_F8E4M3FNUZ:
147149
return os << "F8E4M3FNUZ";
150+
case XLA_FFI_DataType_F8E8M0FNU:
151+
return os << "F8E8M0FNU";
148152
}
149153
}
150154

xla/ffi/api/c_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ typedef enum {
201201
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
202202
XLA_FFI_DataType_F8E5M2FNUZ = 24,
203203
XLA_FFI_DataType_F8E4M3FNUZ = 25,
204+
XLA_FFI_DataType_F4E2M1FN = 32,
205+
XLA_FFI_DataType_F8E8M0FNU = 33,
204206
} XLA_FFI_DataType;
205207
// LINT.ThenChange(ffi_test.cc)
206208

xla/ffi/api/ffi.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ enum class DataType : uint8_t {
7979
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
8080
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
8181
F8E3M4 = XLA_FFI_DataType_F8E3M4,
82+
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
83+
F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU,
8284
};
8385

8486
// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
@@ -106,6 +108,8 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
106108
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
107109
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
108110
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
111+
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
112+
inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU;
109113

110114
inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
111115
return os << static_cast<XLA_FFI_DataType>(dtype);
@@ -127,6 +131,8 @@ constexpr size_t ByteWidth(DataType dtype) {
127131
case DataType::F8E5M2FNUZ:
128132
case DataType::F8E4M3FNUZ:
129133
case DataType::F8E3M4:
134+
case DataType::F4E2M1FN:
135+
case DataType::F8E8M0FNU:
130136
return 1;
131137
case DataType::S16:
132138
case DataType::U16:

xla/ffi/api/ffi_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) {
129129

130130
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));
131131

132+
EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN));
132133
EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
133134
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
134135
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
@@ -137,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) {
137138
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
138139
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
139140
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
141+
EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU));
140142
}
141143

142144
TEST(FfiTest, DataTypeByteWidth) {
@@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {
179181
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128),
180182
ByteWidth(DataType::C128));
181183

184+
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN),
185+
ByteWidth(DataType::F4E2M1FN));
182186
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
183187
ByteWidth(DataType::F8E5M2));
184188
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
@@ -193,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) {
193197
ByteWidth(DataType::F8E4M3FNUZ));
194198
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
195199
ByteWidth(DataType::F8E3M4));
200+
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU),
201+
ByteWidth(DataType::F8E8M0FNU));
196202
}
197203

198204
TEST(FfiTest, ErrorEnumValue) {

xla/ffi/call_frame.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,15 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
264264
case PrimitiveType::C64:
265265
case PrimitiveType::C128:
266266
case PrimitiveType::TOKEN:
267+
case PrimitiveType::F4E2M1FN:
267268
case PrimitiveType::F8E5M2:
268269
case PrimitiveType::F8E4M3:
269270
case PrimitiveType::F8E4M3FN:
270271
case PrimitiveType::F8E4M3B11FNUZ:
271272
case PrimitiveType::F8E5M2FNUZ:
272273
case PrimitiveType::F8E4M3FNUZ:
273274
case PrimitiveType::F8E3M4:
275+
case PrimitiveType::F8E8M0FNU:
274276
return static_cast<XLA_FFI_DataType>(primitive_type);
275277
default:
276278
DCHECK(false) << "Unsupported primitive type "

0 commit comments

Comments
 (0)