
Commit fd28e84

sergey-kozub authored and TensorFlow MLIR Team committed
PR #21380: Add F4E2M1FN and F8E8M0FNU types
Imported from GitHub PR openxla/xla#21380

Previous PR openxla/xla#19096 was rolled back; this PR re-lands it.

This PR adds the F4E2M1FN primitive type (4-bit float with 2 exponent bits and 1 mantissa bit) and the F8E8M0FNU primitive type (8-bit float with 8 exponent bits, no mantissa, and no sign), and enables loads/stores in the same way the S4/U4 types are implemented. This will enable using microscaling (MX) formats ([RFC](openxla/xla#18085)), such as MXFP4.

```c
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 - 1 = 0
- Has positive and negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 - 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- openxla/stablehlo#2582
- jax-ml/ml_dtypes#181
- llvm/llvm-project#95392
- llvm/llvm-project#108877
- jax-ml/ml_dtypes#166
- llvm/llvm-project#107127
- llvm/llvm-project#111028

Copybara import of the project:

-- d7e00c49a4b4f26c06266d6bb941275e67464c01 by Sergey Kozub <[email protected]>:

Add F4E2M1FN and F8E8M0FNU types

Merging this change closes #21380

PiperOrigin-RevId: 715434229
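The bit-layout tables in the commit message fully determine both encodings, so they can be sanity-checked with a tiny decoder. The sketch below is illustrative plain Python written for this note, not the XLA implementation; it only transcribes the encoding rules from the tables above.

```python
def decode_f4e2m1fn(bits: int) -> float:
    """Decode a 4-bit F4E2M1FN value (layout S.EE.M, exponent bias 1)."""
    sign = -1.0 if (bits >> 3) & 1 else 1.0
    exp = (bits >> 1) & 0b11
    mantissa = bits & 1
    if exp == 0:
        # Subnormal range: S.00.0 is +/-0.0, S.00.1 is +/-0.5.
        return sign * 0.5 * mantissa
    # Normal range: +/- 2^(exp - bias) * (1 + mantissa/2), bias = 1.
    return sign * 2.0 ** (exp - 1) * (1.0 + 0.5 * mantissa)

def decode_f8e8m0fnu(bits: int) -> float:
    """Decode an 8-bit F8E8M0FNU value (exponent only, bias 127)."""
    if bits == 0xFF:
        return float("nan")  # 1111'1111 is the single NaN encoding.
    return 2.0 ** (bits - 127)  # No sign, no mantissa, no zero.

# Spot-check against the values listed in the commit message.
assert decode_f4e2m1fn(0b0111) == 6.0   # max normal
assert decode_f4e2m1fn(0b0010) == 1.0   # min normal
assert decode_f4e2m1fn(0b0001) == 0.5   # min subnormal
assert decode_f8e8m0fnu(254) == 2.0 ** 127    # max value
assert decode_f8e8m0fnu(0) == 2.0 ** -127     # min value
```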
1 parent 4484fd7 · commit fd28e84

File tree

1 file changed: +14 −0 lines changed

tests/Dialect/mhlo/ops.mlir

Lines changed: 14 additions & 0 deletions
```diff
@@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor<?x?xf32, #mhlo.type_extensions<b
 
 // -----
 
+func.func @f4e2m1fn(%arg0: tensor<f16>) -> tensor<f4E2M1FN> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f4E2M1FN>
+  func.return %0 : tensor<f4E2M1FN>
+}
+
+// -----
+
 func.func @f8e3m4(%arg0: tensor<f16>) -> tensor<f8E3M4> {
   %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E3M4>
   func.return %0 : tensor<f8E3M4>
@@ -6872,6 +6879,13 @@ func.func @f8e5m2(%arg0: tensor<f16>) -> tensor<f8E5M2> {
 
 // -----
 
+func.func @f8e8m0fnu(%arg0: tensor<f16>) -> tensor<f8E8M0FNU> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E8M0FNU>
+  func.return %0 : tensor<f8E8M0FNU>
+}
+
+// -----
+
 func.func @top_k_1d(%arg0 : tensor<16xf32>) {
   %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
   return
```
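For context, the Related PRs above add matching NumPy dtypes to ml_dtypes. A hedged usage sketch follows; the dtype names `float4_e2m1fn` and `float8_e8m0fnu` are assumed from that project (ml_dtypes >= 0.5.0) and are not part of this commit.

```python
# Assumed API from the ml_dtypes Related PRs, not from this commit.
import numpy as np
import ml_dtypes

# Values exactly representable in F4E2M1FN (see the table above).
x = np.array([0.5, 1.0, 6.0], dtype=ml_dtypes.float4_e2m1fn)
print(x.astype(np.float32))  # expected: [0.5 1.0 6.0]

# F8E8M0FNU holds power-of-two scale factors, as used by MX block formats.
s = np.array([2.0 ** -127, 1.0, 2.0 ** 127], dtype=ml_dtypes.float8_e8m0fnu)
print(s.astype(np.float32))  # expected: [~5.9e-39 1.0 ~1.7e38]
```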
