
Commit b912750

sergey-kozub and Google-ML-Automation authored and committed
PR #21380: Add F4E2M1FN and F8E8M0FNU types
Imported from GitHub PR #21380. The previous PR #19096 was rolled back; this is a retry.

This PR adds the F4E2M1FN primitive type (a 4-bit float with 2 exponent bits and 1 mantissa bit) and the F8E8M0FNU primitive type (an 8-bit float with 8 exponent bits, no mantissa, and no sign), and enables loads/stores in the same way the S4/U4 types are implemented. This will enable the use of microscaling (MX) formats ([RFC](#18085)), such as MXFP4.

```
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 - 1 = 0
- Has positive and negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 - 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- openxla/stablehlo#2582
- jax-ml/ml_dtypes#181
- llvm/llvm-project#95392
- llvm/llvm-project#108877
- jax-ml/ml_dtypes#166
- llvm/llvm-project#107127
- llvm/llvm-project#111028

Copybara import of the project:

-- d7e00c4 by Sergey Kozub <[email protected]>:

Add F4E2M1FN and F8E8M0FNU types

Merging this change closes #21380

COPYBARA_INTEGRATE_REVIEW=#21380 from openxla:skozub/e2m1_e8m0 d7e00c4
PiperOrigin-RevId: 715434229
1 parent 88a2497 commit b912750
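For reference, here is a minimal standalone sketch (not from this PR) of how the F4E2M1FN bit layout described above decodes to float. The function name `DecodeF4E2M1FN` is invented for illustration; XLA itself uses `tsl::float4_e2m1fn` from ml_dtypes rather than hand-rolled decoding.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Sketch only: decode a 4-bit F4E2M1FN code (1 sign bit, 2 exponent bits
// with bias 1, 1 mantissa bit) into a float.
float DecodeF4E2M1FN(uint8_t bits) {
  assert(bits < 16);
  int sign = (bits >> 3) & 1;
  int exp = (bits >> 2) & 0b11;  // stored exponent, bias 1
  int man = bits & 1;            // single mantissa bit
  float magnitude;
  if (exp == 0) {
    // Subnormal: 0.0 or 0.5 (min subnormal from the table above).
    magnitude = 0.5f * man;
  } else {
    // Normal: (1 + mantissa/2) * 2^(exp - bias); max is 1.5 * 4 = 6.0.
    magnitude = (1.0f + 0.5f * man) * static_cast<float>(1 << (exp - 1));
  }
  return sign ? -magnitude : magnitude;
}

int main() {
  // Prints all 16 codes: signed zeros, ±0.5 ... ±6.0, no Inf/NaN encodings.
  for (uint8_t b = 0; b < 16; ++b) {
    std::printf("0b%d%d%d%d -> %g\n", b >> 3, (b >> 2) & 1, (b >> 1) & 1,
                b & 1, DecodeF4E2M1FN(b));
  }
}
```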

File tree

79 files changed: +1851 −376 lines


xla/array2d_test.cc

Lines changed: 28 additions & 0 deletions

```diff
@@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) {
   EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
 }
 
+TEST(Array2dTest, LinspaceF4E2M1FN) {
+  auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);
+
+  EXPECT_EQ(arr->n1(), 3);
+  EXPECT_EQ(arr->n2(), 2);
+
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0);  // 2.5 rounded down
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0);  // 3.5 rounded up
+}
+
+TEST(Array2dTest, LinspaceF8E8M0FNU) {
+  auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);
+
+  EXPECT_EQ(arr->n1(), 3);
+  EXPECT_EQ(arr->n2(), 2);
+
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0);  // 1.5 rounded up
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0);  // 2.5 rounded down
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0);  // 3.0 rounded up
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0);  // 3.5 rounded up
+}
+
 TEST(Array2dTest, Stringification) {
   auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
   const std::string expected = R"([[1, 1.5],
```
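The expectations above appear consistent with round-to-nearest-even conversion: in F4E2M1FN, 2.5 lies halfway between the representable values 2.0 and 3.0 and ties toward the even mantissa (2.0), while 3.5 ties between 3.0 and 4.0 and resolves to 4.0. F8E8M0FNU can only represent powers of two, so 1.5 rounds up to 2.0, 2.5 rounds down to 2.0, and the halfway case 3.0 rounds up to 4.0 here.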

xla/backends/gpu/codegen/transforms/expand_float_ops.cc

Lines changed: 119 additions & 72 deletions
(Large diff not rendered.)

xla/backends/gpu/codegen/transforms/lower_tensors.cc

Lines changed: 35 additions & 24 deletions

```diff
@@ -299,7 +299,8 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
 ml::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
                     Value linear_index, mlir::ImplicitLocOpBuilder& b) {
   Type element_type = tensor.getType().getElementType();
-  if (element_type == b.getI4Type()) {
+  if (element_type.isIntOrFloat() &&
+      element_type.getIntOrFloatBitWidth() == 4) {
     element_type = b.getI8Type();
   }
   auto ptr = ml::LLVMPointerType::get(b.getContext());
@@ -328,7 +329,8 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
     auto linear_index = GetLinearIndex(op.getIndices(), b);
     Type element_type = op.getTensor().getType().getElementType();
     Value is_low_nibble = nullptr;
-    if (element_type == rewriter.getI4Type()) {
+    if (element_type.isIntOrFloat() &&
+        element_type.getIntOrFloatBitWidth() == 4) {
       std::tie(linear_index, is_low_nibble) =
           GetI4IndexAndNibble(linear_index, b);
     }
@@ -342,7 +344,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
       auto high_value = b.create<mlir::arith::ShRUIOp>(
           load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
       load = b.create<mlir::arith::TruncIOp>(
-          op.getType(),
+          rewriter.getI4Type(),
           b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
     }
 
@@ -378,6 +380,7 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
 
     auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
        op.getSource());
+    mlir::Type source_element_type = source.getType().getElementType();
 
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto linear_index = GetLinearIndex(op.getIndices(), b);
@@ -386,7 +389,9 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
     if (vector_type.getElementType().isInteger(1)) {
       vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
     }
-    if (op.getVectorType().getElementType().isInteger(4)) {
+    mlir::Type gep_element_type = vector_type.getElementType();
+    if (gep_element_type.isIntOrFloat() &&
+        gep_element_type.getIntOrFloatBitWidth() == 4) {
       linear_index = b.create<arith::ShRUIOp>(
           linear_index,
           b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
     auto llvm_vector_type = converter.convertType(vector_type);
     auto loaded = b.create<ml::LoadOp>(llvm_vector_type, gep).getResult();
 
-    if (source.getType().getElementType().isInteger(1)) {
+    if (source_element_type.isInteger(1)) {
       Value zero = b.create<mlir::arith::ConstantOp>(
           mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
       loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
-    } else if (source.getType().getElementType().isInteger(4)) {
+    } else if (source_element_type.isIntOrFloat() &&
+               source_element_type.getIntOrFloatBitWidth() == 4) {
       // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
       // elements.
       loaded = PermutePairsInVector(loaded, b);
@@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
     auto scalar_value = op.getScalar();
 
     // For i4 we store 2 values into one byte. This needs special handling here.
-    if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
+    if (tensor_dest.getType().getElementType().isIntOrFloat() &&
+        tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
       // We need to use directly op.getDest() as input, otherwise the following
       // rewrite might remove the only user of it.
       tensor_dest = op.getDest();
@@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
       auto tensor_dest_i8 =
           b.create<UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
              .getResult(0);
+      if (scalar_value.getType() != rewriter.getI4Type()) {
+        scalar_value =
+            b.create<arith::BitcastOp>(rewriter.getI4Type(), scalar_value);
+      }
       scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);
 
       // We need AtomicRMWOp because it can happen that different threads try to
@@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern<vector::TransferWriteOp> {
     auto linear_index = GetLinearIndex(op.getIndices(), b);
 
     mlir::Value vector_value = op.getVector();
-    if (op.getVectorType().getElementType().isInteger(1)) {
+    mlir::Type vector_element_type = op.getVectorType().getElementType();
+    if (vector_element_type.isInteger(1)) {
       vector_value = b.create<arith::ExtUIOp>(
           op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
           vector_value);
     }
-    if (op.getVectorType().getElementType().isInteger(4)) {
+    if (vector_element_type.isIntOrFloat() &&
+        vector_element_type.getIntOrFloatBitWidth() == 4) {
       linear_index = b.create<arith::ShRUIOp>(
           linear_index,
           b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -575,21 +588,19 @@ ml::GlobalOp CreateGlobalOp(mlir::Attribute value,
   // Needed to support complex element type.
   mlir::LLVMTypeConverter converter(b.getContext());
   auto llvm_element_type = converter.convertType(element_type);
-  if (mlir::isa<mlir::IntegerType>(element_type)) {
-    int bit_width = mlir::cast<mlir::IntegerType>(element_type).getWidth();
-    if (bit_width == 4) {
-      num_elements = CeilOfRatio<int64_t>(num_elements, 2);
-      llvm_element_type = b.getI8Type();
-      auto unpacked_data =
-          mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
-      std::vector<char> packed_data(num_elements);
-      absl::Span<char> packed_data_span =
-          absl::MakeSpan(packed_data.data(), packed_data.size());
-      PackIntN(4, unpacked_data, packed_data_span);
-      value = mlir::DenseElementsAttr::getFromRawBuffer(
-          mlir::RankedTensorType::get({num_elements}, llvm_element_type),
-          packed_data);
-    }
+  if (element_type.isIntOrFloat() &&
+      element_type.getIntOrFloatBitWidth() == 4) {
+    num_elements = CeilOfRatio<int64_t>(num_elements, 2);
+    llvm_element_type = b.getI8Type();
+    auto unpacked_data =
+        mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
+    std::vector<char> packed_data(num_elements);
+    absl::Span<char> packed_data_span =
+        absl::MakeSpan(packed_data.data(), packed_data.size());
+    PackIntN(4, unpacked_data, packed_data_span);
+    value = mlir::DenseElementsAttr::getFromRawBuffer(
+        mlir::RankedTensorType::get({num_elements}, llvm_element_type),
+        packed_data);
   }
   auto array_ty = ml::LLVMArrayType::get(llvm_element_type, num_elements);
   std::string name;
```
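The recurring `linear_index >> 1` and the `PackIntN(4, ...)` call reflect that two 4-bit elements share one byte. Below is a minimal sketch of that packing, assuming the constant-packing order implied by the `f4_constant` test further down (first element in the high nibble); the helper names `PackNibbles` and `UnpackNibble` are invented for illustration and are not XLA APIs.

```cpp
#include <cstdint>
#include <vector>

// Sketch only: pack 4-bit codes two per byte, first element in the high
// nibble (the order the lowered global constant is checked against below).
std::vector<uint8_t> PackNibbles(const std::vector<uint8_t>& codes) {
  std::vector<uint8_t> packed((codes.size() + 1) / 2, 0);
  for (size_t i = 0; i < codes.size(); ++i) {
    int shift = (i % 2 == 0) ? 4 : 0;  // even index -> high nibble
    packed[i / 2] |= static_cast<uint8_t>((codes[i] & 0xF) << shift);
  }
  return packed;
}

// Sketch only: recover element i; the byte index is i / 2, mirroring the
// shift-right-by-one applied to linear_index in the rewrites above.
uint8_t UnpackNibble(const std::vector<uint8_t>& packed, size_t i) {
  uint8_t byte = packed[i / 2];
  return (i % 2 == 0) ? static_cast<uint8_t>(byte >> 4)
                      : static_cast<uint8_t>(byte & 0xF);
}
```

With the F4E2M1FN codes for 0.5 (0b0001), -0.5 (0b1001), and 2.0 (0b0100), `PackNibbles` yields {0x19, 0x40} = {25, 64}, which matches the `dense<[25, 64]>` constant checked in the lower_tensors.mlir test below.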

xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir

Lines changed: 50 additions & 0 deletions

```diff
@@ -115,3 +115,53 @@ module {
 // CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32
 // CHECK: arith.truncf %[[EXT]] : f32 to f16
 // CHECK-NOT: arith.truncf
+
+// -----
+
+module {
+  func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 {
+    %ret = arith.extf %arg0 : f4E2M1FN to f16
+    return %ret : f16
+  }
+}
+
+// CHECK-LABEL: @f4_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+module {
+  func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN {
+    %ret = arith.truncf %arg0 : f16 to f4E2M1FN
+    return %ret : f4E2M1FN
+  }
+}
+
+// CHECK-LABEL: @f16_to_f4
+// CHECK-NOT: arith.truncf
+
+// -----
+
+module {
+  func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN {
+    %ret = math.absf %arg0 : f4E2M1FN
+    return %ret : f4E2M1FN
+  }
+}
+
+// CHECK-LABEL: @f4_abs
+// CHECK-NOT: math.absf
+// CHECK: arith.constant 7 : i4
+
+// -----
+
+module {
+  func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU {
+    %ret = math.absf %arg0 : f8E8M0FNU
+    return %ret : f8E8M0FNU
+  }
+}
+
+// CHECK-LABEL: @e8m0_abs
+// CHECK-NOT: math.absf
+// CHECK: return %arg0
```
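Since F4E2M1FN carries its sign in the top bit of the 4-bit encoding, `math.absf` can be lowered to clearing that bit, which is what the `arith.constant 7 : i4` check (a 0b0111 mask) reflects. F8E8M0FNU has no sign bit at all, so `math.absf` folds away and the function simply returns its argument.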

xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir

Lines changed: 39 additions & 0 deletions

```diff
@@ -785,3 +785,42 @@ func.func @vector_atomic_rmw(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 // CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32
 // CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32
 // CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32
+
+// -----
+
+func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN {
+  %cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN>
+  %extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN>
+  %extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN>
+  %0 = arith.addf %extracted, %extracted_0 : f4E2M1FN
+  return %0 : f4E2M1FN
+}
+// CHECK: llvm.mlir.global private constant
+// CHECK-SAME: dense<[25, 64]>
+// CHECK-LABEL: @f4_constant
+
+// -----
+
+func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> {
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0.0 : f4E2M1FN
+  %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN>
+  func.return %out : vector<2xf4E2M1FN>
+}
+// CHECK-LABEL: @transfer_read_f4
+// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8]
+// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4>
+// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN>
+// CHECK: return %[[OUT]] : vector<2xf4E2M1FN>
+
+// -----
+
+func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1},
+                             %arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> {
+  %c10 = arith.constant 10 : index
+  %out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN>
+  func.return %out : tensor<43xf4E2M1FN>
+}
+// CHECK-LABEL: @transfer_write_f4
+// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8
+// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4>
```
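The getelementptr offsets in these checks follow from the two-elements-per-byte packing: a read at element index 16 touches byte 16 / 2 = 8, and a write at index 10 touches byte 10 / 2 = 5.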
