
Commit d7d5af7

Add F4E2M1FN type: add tests
1 parent 999bf96 commit d7d5af7

25 files changed: +186 additions, -70 deletions

xla/array2d_test.cc

Lines changed: 14 additions & 0 deletions
@@ -219,6 +219,20 @@ TEST(Array2dTest, LinspaceF8E3M4) {
   EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
 }
 
+TEST(Array2dTest, LinspaceF4E2M1FN) {
+  auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);
+
+  EXPECT_EQ(arr->n1(), 3);
+  EXPECT_EQ(arr->n2(), 2);
+
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0);  // 2.5 rounded down
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0);  // 3.5 rounded up
+}
+
 TEST(Array2dTest, Stringification) {
   auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
   const std::string expected = R"([[1, 1.5],
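
Why those two expectations: E2M1 can only represent the magnitudes 0, 0.5, 1, 1.5, 2, 3, 4 and 6, so the linspace values 2.5 and 3.5 fall exactly between representable numbers and round-to-nearest-even sends them to 2.0 and 4.0. A standalone sketch (not XLA code; DecodeE2M1 is a hypothetical helper) that decodes all 16 bit patterns, assuming the OCP MX FP4 layout of 1 sign, 2 exponent (bias 1) and 1 mantissa bit with no Inf/NaN encodings:

    #include <cmath>
    #include <cstdio>

    // Decode one E2M1 bit pattern to float (illustrative only).
    float DecodeE2M1(unsigned bits) {
      float sign = (bits & 0x8) ? -1.0f : 1.0f;
      unsigned exponent = (bits >> 1) & 0x3;
      unsigned mantissa = bits & 0x1;
      if (exponent == 0) return sign * 0.5f * mantissa;  // subnormal: 0 or 0.5
      return sign * std::ldexp(1.0f + 0.5f * mantissa, exponent - 1);
    }

    int main() {
      // Prints the magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6 and their negatives.
      for (unsigned bits = 0; bits < 16; ++bits) {
        std::printf("0x%X -> %g\n", bits, DecodeE2M1(bits));
      }
      // 2.5 sits midway between 2 and 3, 3.5 midway between 3 and 4;
      // ties-to-even picks 2.0 and 4.0, matching the test expectations above.
      return 0;
    }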

xla/backends/gpu/codegen/transforms/lower_tensors.cc

Lines changed: 17 additions & 9 deletions
@@ -297,7 +297,7 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
 mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
                             Value linear_index, mlir::ImplicitLocOpBuilder& b) {
   Type element_type = tensor.getType().getElementType();
-  if (element_type == b.getI4Type()) {
+  if (element_type.getIntOrFloatBitWidth() == 4) {
     element_type = b.getI8Type();
   }
   auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
@@ -325,8 +325,9 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto linear_index = GetLinearIndex(op.getIndices(), b);
     Type element_type = op.getTensor().getType().getElementType();
+
     Value is_low_nibble = nullptr;
-    if (element_type == rewriter.getI4Type()) {
+    if (element_type.getIntOrFloatBitWidth() == 4) {
       std::tie(linear_index, is_low_nibble) =
           GetI4IndexAndNibble(linear_index, b);
     }
@@ -341,7 +342,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
       auto high_value = b.create<mlir::arith::ShRUIOp>(
           load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
       load = b.create<mlir::arith::TruncIOp>(
-          op.getType(),
+          rewriter.getI4Type(),
           b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
     }
 
@@ -377,6 +378,7 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
 
     auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
        op.getSource());
+    mlir::Type source_element_type = source.getType().getElementType();
 
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto linear_index = GetLinearIndex(op.getIndices(), b);
@@ -385,7 +387,8 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
     if (vector_type.getElementType().isInteger(1)) {
       vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
     }
-    if (op.getVectorType().getElementType().isInteger(4)) {
+    mlir::Type gep_element_type = vector_type.getElementType();
+    if (gep_element_type.getIntOrFloatBitWidth() == 4) {
       linear_index = b.create<arith::ShRUIOp>(
           linear_index,
           b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -397,11 +400,11 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
     auto loaded =
         b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
 
-    if (source.getType().getElementType().isInteger(1)) {
+    if (source_element_type.isInteger(1)) {
       Value zero = b.create<mlir::arith::ConstantOp>(
          mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
       loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
-    } else if (source.getType().getElementType().isInteger(4)) {
+    } else if (source_element_type.getIntOrFloatBitWidth() == 4) {
       // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
       // elements.
       loaded = PermutePairsInVector(loaded, b);
@@ -430,7 +433,7 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
     auto scalar_value = op.getScalar();
 
     // For i4 we store 2 values into one byte. This needs special handling here.
-    if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
+    if (tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
       // We need to use directly op.getDest() as input, otherwise the following
       // rewrite might remove the only user of it.
       tensor_dest = op.getDest();
@@ -448,6 +451,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
       auto tensor_dest_i8 =
          b.create<UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
              .getResult(0);
+      if (scalar_value.getType() != rewriter.getI4Type()) {
+        scalar_value =
+            b.create<arith::BitcastOp>(rewriter.getI4Type(), scalar_value);
+      }
       scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);
 
       // We need AtomicRMWOp because it can happen that different threads try to
@@ -507,12 +514,13 @@ struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
     auto linear_index = GetLinearIndex(op.getIndices(), b);
 
     mlir::Value vector_value = op.getVector();
-    if (op.getVectorType().getElementType().isInteger(1)) {
+    mlir::Type vector_element_type = op.getVectorType().getElementType();
+    if (vector_element_type.isInteger(1)) {
       vector_value = b.create<arith::ExtUIOp>(
          op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
          vector_value);
     }
-    if (op.getVectorType().getElementType().isInteger(4)) {
+    if (vector_element_type.getIntOrFloatBitWidth() == 4) {
       linear_index = b.create<arith::ShRUIOp>(
          linear_index,
          b.create<arith::ConstantIntOp>(1, linear_index.getType()));
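
The net effect of these changes is that the pass now routes any 4-bit element type (i4 and now F4E2M1FN) through the same packed path: two elements share a byte, so the linear index is shifted right by one to find the byte and its low bit picks the nibble, which is what GetI4IndexAndNibble and the ShRUIOp on the index implement. A host-side sketch of that addressing (not the MLIR lowering itself; the "even index is the low nibble" order is an assumption for this example only, and the pass notes that LLVM and XLA pack nibbles in opposite orders, hence PermutePairsInVector after vector loads):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Read a 4-bit element out of a byte-packed buffer (illustration only).
    uint8_t Load4BitElement(const uint8_t* packed, size_t linear_index) {
      uint8_t byte = packed[linear_index >> 1];      // two 4-bit elements per byte
      bool is_low_nibble = (linear_index & 1) == 0;  // assumed nibble order
      return is_low_nibble ? (byte & 0x0F) : (byte >> 4);
    }

    int main() {
      const uint8_t packed[2] = {0x21, 0x43};  // elements 1, 2, 3, 4 in the assumed order
      for (size_t i = 0; i < 4; ++i) {
        std::printf("element %zu = %u\n", i, Load4BitElement(packed, i));
      }
      return 0;
    }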

xla/fp_util_test.cc

Lines changed: 53 additions & 0 deletions
@@ -119,6 +119,59 @@ class FP8E4M3DistanceTest : public ::testing::Test {};
 using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
 TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);
 
+TEST(FPDistanceTest, F4E2M1FNDistance) {
+  // a & b are equal
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)),
+            0);
+
+  // a & b have the same exponents
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)),
+            1);
+
+  // a & b have different exponents
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)),
+            2);
+
+  // 1 from 0 in the positive direction
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
+                tsl::float4_e2m1fn(0)),
+            1);
+
+  // 1 from 0 in the negative direction
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                -std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
+                tsl::float4_e2m1fn(0)),
+            1);
+
+  // a & b have different signs
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                -std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
+                std::numeric_limits<tsl::float4_e2m1fn>::denorm_min()),
+            2);
+
+  // 1 non denorm from 0 in the positive direction
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                std::numeric_limits<tsl::float4_e2m1fn>::min(),
+                tsl::float4_e2m1fn(0)),
+            2);
+
+  // 1 non denorm from 0 in the negative direction
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                -std::numeric_limits<tsl::float4_e2m1fn>::min(),
+                tsl::float4_e2m1fn(0)),
+            2);
+
+  // a & b have different signs
+  EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
+                -std::numeric_limits<tsl::float4_e2m1fn>::min(),
+                std::numeric_limits<tsl::float4_e2m1fn>::min()),
+            4);
+}
+
 TEST(FPDistanceTest, F8E3M4Distance) {
   // a & b are equal
   EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
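
The expected distances follow directly from E2M1's value ladder: the non-negative values are 0, 0.5 (denorm_min), 1 (min), 1.5, 2, 3, 4 and 6, so 4.0 and 6.0 are adjacent, 2.0 to 4.0 crosses 3.0, and -min() to min() spans four steps. A rough sketch of that counting (this assumes CalculateDistanceInFloats counts steps between representable values, with +0 and -0 collapsed; the real helper is templated over arbitrary floating-point types):

    #include <algorithm>
    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    // Count steps between two exactly-representable E2M1 values.
    int E2M1Steps(float a, float b) {
      static const std::vector<float> ladder = {-6, -4, -3, -2, -1.5, -1, -0.5, 0,
                                                0.5, 1,  1.5, 2,  3,   4,  6};
      auto ia = std::find(ladder.begin(), ladder.end(), a) - ladder.begin();
      auto ib = std::find(ladder.begin(), ladder.end(), b) - ladder.begin();
      return std::abs(static_cast<int>(ia - ib));
    }

    int main() {
      std::printf("%d\n", E2M1Steps(4.0f, 6.0f));   // 1: adjacent values
      std::printf("%d\n", E2M1Steps(2.0f, 4.0f));   // 2: 2.0 -> 3.0 -> 4.0
      std::printf("%d\n", E2M1Steps(0.5f, 0.0f));   // 1: denorm_min to zero
      std::printf("%d\n", E2M1Steps(-1.0f, 1.0f));  // 4: -min() to min()
      return 0;
    }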

xla/hlo/builder/lib/math.cc

Lines changed: 7 additions & 4 deletions
@@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) {
     case F32:
       return Eq(BitcastConvertType(operand, U32),
                 ConstantR0WithType(&b, U32, uint32_t{1} << 31));
+    case F4E2M1FN:
     case F8E3M4:
     case F8E4M3:
     case F8E5M2:
@@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
     PrimitiveType a_x_type = a_shape.element_type();
     bool needs_upcast = false;
-    for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
-                               F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
+    for (PrimitiveType type :
+         {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
+          F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
       if (a_shape.element_type() == type) {
         needs_upcast = true;
         break;
@@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
     }
     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
     bool needs_upcast = false;
-    for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
-                               F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
+    for (PrimitiveType type :
+         {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
+          F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
       if (a_shape.element_type() == type) {
         needs_upcast = true;
         break;
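
Both Igamma hunks extend the same "needs upcast" list: sub-single-precision element types are converted to F32 before the computation runs and the result is converted back, and F4E2M1FN now joins that set (the list is also re-sorted alphabetically). A minimal plain-C++ sketch of that gate, with an illustrative enum rather than xla::PrimitiveType:

    #include <initializer_list>

    enum class Ty { BF16, F16, F32, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ,
                    F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ };

    // Types in this list are computed in F32 and converted back afterwards.
    bool NeedsUpcastToF32(Ty t) {
      for (Ty type : {Ty::BF16, Ty::F16, Ty::F4E2M1FN, Ty::F8E3M4, Ty::F8E4M3,
                      Ty::F8E4M3B11FNUZ, Ty::F8E4M3FN, Ty::F8E4M3FNUZ,
                      Ty::F8E5M2, Ty::F8E5M2FNUZ}) {
        if (t == type) return true;
      }
      return false;
    }

    int main() { return NeedsUpcastToF32(Ty::F4E2M1FN) ? 0 : 1; }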

xla/hlo/builder/lib/math_test.cc

Lines changed: 24 additions & 10 deletions
@@ -95,17 +95,22 @@ class MathTypedTest : public MathTest {
     Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});
 
     bool has_inf = std::numeric_limits<T>::has_infinity;
+    bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
+    bool is_finite = !has_inf && !has_nan;
+    bool is_nan_only = !has_inf && has_nan;
+
     auto expected = LiteralUtil::MakeTupleOwned(
-        LiteralUtil::CreateR1<bool>(
-            {true, true, true, true, true, false, false, false, false}),
+        LiteralUtil::CreateR1<bool>({true, true, true, true, true, is_finite,
+                                     is_finite, is_finite, is_finite}),
         LiteralUtil::CreateR1<bool>({false, false, false, false, false, has_inf,
                                      has_inf, false, false}),
         LiteralUtil::CreateR1<bool>(
             {false, false, false, false, false, has_inf, false, false, false}),
         LiteralUtil::CreateR1<bool>(
             {false, false, false, false, false, false, has_inf, false, false}),
         LiteralUtil::CreateR1<bool>({false, false, false, false, false,
-                                     !has_inf, !has_inf, true, true}));
+                                     is_nan_only, is_nan_only, has_nan,
+                                     has_nan}));
     ComputeAndCompareLiteral(&b, expected, {});
   }
 
@@ -118,10 +123,11 @@ class MathTypedTest : public MathTest {
         LiteralUtil::CreateR1<T>({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}),
         &b));
 
+    bool is_mx = std::is_same_v<T, tsl::float4_e2m1fn>;
     ComputeAndCompareLiteral(
         &b,
         LiteralUtil::CreateR1<bool>(
-            {has_negative_zero_v<T>, false, false, false, false, false, false}),
+            {has_negative_zero_v<T>, false, false, false, false, false, is_mx}),
         {}, error_spec_);
   }
 
@@ -136,6 +142,9 @@ class MathTypedTest : public MathTest {
   // For good measure, we also check pow with an exponent other than 0.5.
   void TestSqrtPowInequivalence() {
     SetFastMathDisabled(true);
+    if (std::is_same_v<T, tsl::float4_e2m1fn>) {
+      GTEST_SKIP() << "Skipping due to low precision";
+    }
 
     // Tests disable constant folding by default, but this test needs it
     // enabled, otherwise we don't tickle the bug we're trying to catch.
@@ -181,18 +190,23 @@ class MathTypedTest : public MathTest {
         &b);
     Erf(x);
 
-    bool has_inf = std::numeric_limits<T>::has_infinity;
-    std::vector<T> expected = {
-        has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)};
+    bool inf_as_nan = !std::numeric_limits<T>::has_infinity &&
+                      std::numeric_limits<T>::has_quiet_NaN;
+    std::vector<T> expected = {inf_as_nan ? nan : T(-1),
+                               inf_as_nan ? nan : T(1),
+                               T(-0),
+                               T(0),
+                               T(-1),
+                               T(1)};
 
     ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
   }
 };
 
 using TestTypes =
-    ::testing::Types<tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fnuz,
-                     tsl::float8_e4m3b11fnuz, tsl::float8_e5m2,
-                     tsl::float8_e5m2fnuz,
+    ::testing::Types<tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3,
+                     tsl::float8_e4m3fnuz, tsl::float8_e4m3b11fnuz,
+                     tsl::float8_e5m2, tsl::float8_e5m2fnuz,
 #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
                      Eigen::half,
 #endif
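
The rewritten expectations key off std::numeric_limits traits, so a single test body covers three kinds of types: those with both Inf and NaN, the NaN-only *fnuz formats, and the fully finite F4E2M1FN, which lands in the new is_finite branch. A quick way to see where a type falls is to print its traits; sketch below (the tsl include path is an assumption, and the expected output for float4_e2m1fn is has_infinity=0, has_quiet_NaN=0, min=1, denorm_min=0.5, max=6):

    #include <cstdio>
    #include <limits>

    #include "tsl/platform/ml_dtypes.h"  // assumed location of the tsl:: aliases

    template <typename T>
    void PrintTraits(const char* name) {
      std::printf("%s: has_infinity=%d has_quiet_NaN=%d min=%g denorm_min=%g max=%g\n",
                  name, std::numeric_limits<T>::has_infinity,
                  std::numeric_limits<T>::has_quiet_NaN,
                  static_cast<float>(std::numeric_limits<T>::min()),
                  static_cast<float>(std::numeric_limits<T>::denorm_min()),
                  static_cast<float>(std::numeric_limits<T>::max()));
    }

    int main() {
      PrintTraits<tsl::float4_e2m1fn>("float4_e2m1fn");
      PrintTraits<tsl::float8_e4m3fnuz>("float8_e4m3fnuz");
      PrintTraits<tsl::float8_e5m2>("float8_e5m2");
      return 0;
    }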

xla/hlo/transforms/simplifiers/float_normalization.cc

Lines changed: 3 additions & 0 deletions
@@ -217,6 +217,9 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
       hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) {
         if (subshape->element_type() == from) {
           subshape->set_element_type(to);
+          if (subshape->has_layout() && from == F4E2M1FN) {
+            subshape->mutable_layout()->set_element_size_in_bits(0);
+          }
         }
       });
   float_normalization_->UpdateLayout(hlo->mutable_shape());
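
The layout reset matters because, as the guard on from == F4E2M1FN suggests, F4 buffers use a packed layout with element_size_in_bits = 4 (two elements per byte); once float normalization widens the subshape to a supported type, that packing no longer applies, and 0 restores the default of "use the element type's natural size". A back-of-the-envelope sketch of the size arithmetic (illustrative only, not XLA's shape machinery):

    #include <cstddef>
    #include <cstdio>

    // Bytes needed for a buffer with an explicit per-element bit width.
    size_t BufferBytes(size_t num_elements, size_t element_size_in_bits) {
      return (num_elements * element_size_in_bits + 7) / 8;  // round up to whole bytes
    }

    int main() {
      std::printf("8 x F4E2M1FN, packed 4-bit layout: %zu bytes\n", BufferBytes(8, 4));
      std::printf("8 x F16 after normalization:       %zu bytes\n", BufferBytes(8, 16));
      return 0;
    }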

xla/hlo/transforms/simplifiers/float_normalization_test.cc

Lines changed: 3 additions & 1 deletion
@@ -150,7 +150,9 @@ class FloatNormalizationF8Test
       public ::testing::WithParamInterface<PrimitiveType> {};
 
 INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test,
-                         ::testing::Values(F8E3M4, F8E4M3, F8E5M2));
+                         ::testing::Values(F4E2M1FN, F8E3M4, F8E4M3,
+                                           F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ,
+                                           F8E5M2, F8E5M2FNUZ));
 
 TEST_F(FloatNormalizationTest, NoopIfSupported) {
   auto builder = HloComputation::Builder(TestName());

xla/mlir/utils/type_util.cc

Lines changed: 5 additions & 1 deletion
@@ -32,6 +32,8 @@ absl::StatusOr<mlir::Type> ConvertPrimitiveTypeToMlirType(
   switch (type) {
     case xla::PrimitiveType::PRED:
       return b.getI1Type();
+    case xla::PrimitiveType::F4E2M1FN:
+      return b.getFloat4E2M1FNType();
     case xla::PrimitiveType::F8E5M2:
       return b.getFloat8E5M2Type();
     case xla::PrimitiveType::F8E4M3:
@@ -78,7 +80,9 @@ absl::StatusOr<mlir::Type> ConvertPrimitiveTypeToMlirType(
 }
 
 xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) {
-  if (type.isFloat8E5M2()) {
+  if (type.isFloat4E2M1FN()) {
+    return xla::PrimitiveType::F4E2M1FN;
+  } else if (type.isFloat8E5M2()) {
     return xla::PrimitiveType::F8E5M2;
   } else if (type.isFloat8E4M3()) {
     return xla::PrimitiveType::F8E4M3;

xla/mlir/utils/type_util_test.cc

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ INSTANTIATE_TEST_SUITE_P(
     Execute, TypeUtilTest,
     ::testing::ValuesIn(std::vector<TypeUtilTestParam>(
         {{PRED, [](mlir::Builder b) { return b.getI1Type(); }},
+         {F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }},
          {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }},
          {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }},
          {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }},

xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir

Lines changed: 7 additions & 0 deletions
@@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor<?x?xf32, #mhlo.type_extensions<b
 
 // -----
 
+func.func @f4e2m1fn(%arg0: tensor<f16>) -> tensor<f4E2M1FN> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f4E2M1FN>
+  func.return %0 : tensor<f4E2M1FN>
+}
+
+// -----
+
 func.func @f8e3m4(%arg0: tensor<f16>) -> tensor<f8E3M4> {
   %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E3M4>
   func.return %0 : tensor<f8E3M4>
