Commit ca1ea7c

apivovarov authored and TensorFlow MLIR Team committed
PR #16585: Add support for float8_e4m3 and float8_e3m4 types
Imported from GitHub PR openxla/xla#16585

This PR adds support for the f8E4M3 and f8E3M4 types to XLA (mainly to the cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa),
  including the implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
  including the implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:

```
bazel test \
  //xla:array2d_test \
  //xla:fp_util_test \
  //xla:literal_comparison_test \
  //xla:literal_test \
  //xla/mlir/utils:type_util_test \
  //xla:primitive_util_test \
  //xla/python/ifrt:dtype_test \
  //xla/python:xla_client_test \
  //xla/service:elemental_ir_emitter_test \
  //xla/service:float_normalization_test \
  //xla/service/gpu/tests:float_conversions_test \
  //xla/tests:array_elementwise_ops_test \
  //xla/tests:constants_test \
  //xla/tests:convert_test \
  //xla/tests:float8_test \
  //xla:util_test

bazel test \
  //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
  //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
  //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
  //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
  //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:

- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)

Copybara import of the project:

-- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

PiperOrigin-RevId: 681551979
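As a quick cross-check of the arithmetic in the two tables above, the sketch below decodes an 8-bit pattern under the IEEE 754 conventions described for these formats. It is a hypothetical reference decoder, not XLA's implementation from this PR; the helper `decode_fp8` and its test bit patterns are illustrative assumptions.

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Hypothetical reference decoder (not XLA's code): interprets `bits` as a
   1-sign / E-exponent / M-mantissa IEEE 754-style float8 with exponent bias
   2^(E-1) - 1, subnormals, signed zeros, infinities, and NaNs. */
static double decode_fp8(uint8_t bits, int E, int M) {
  int bias = (1 << (E - 1)) - 1;            /* 7 for f8E4M3, 3 for f8E3M4 */
  int sign = (bits >> (E + M)) & 1;
  int exp  = (bits >> M) & ((1 << E) - 1);
  int mant = bits & ((1 << M) - 1);
  double s = sign ? -1.0 : 1.0;
  if (exp == (1 << E) - 1)                  /* all-ones exponent field */
    return mant == 0 ? s * INFINITY : NAN;  /* IEEE inf / NaN encodings */
  if (exp == 0)                             /* subnormal: no implicit bit */
    return s * ldexp((double)mant, 1 - bias - M);
  return s * ldexp((double)((1 << M) + mant), exp - bias - M);
}

int main(void) {
  /* f8E4M3: S.1110.111 -> max normal 240; S.0000.001 -> min subnormal 2^-9 */
  assert(decode_fp8(0x77, 4, 3) == 240.0);
  assert(decode_fp8(0x01, 4, 3) == ldexp(1.0, -9));
  assert(isinf(decode_fp8(0x78, 4, 3)));    /* S.1111.000 -> +inf */
  assert(isnan(decode_fp8(0x79, 4, 3)));    /* S.1111.001 -> NaN  */
  /* f8E3M4: S.110.1111 -> max normal 15.5; S.000.0001 -> min subnormal 2^-6 */
  assert(decode_fp8(0x6F, 3, 4) == 15.5);
  assert(decode_fp8(0x01, 3, 4) == ldexp(1.0, -6));
  printf("float8 decoding checks passed\n");
  return 0;
}
```

Compiled with `cc decode_fp8.c -lm`, the assertions reproduce the max/min normal and subnormal values listed in both tables.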
1 parent 9f778be · commit ca1ea7c

File tree

3 files changed: +42 −0 lines


tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir

Lines changed: 14 additions & 0 deletions
```diff
@@ -1805,6 +1805,20 @@ func.func @type_ui64(%arg0: tensor<ui64>, %arg1: tensor<ui64>) -> tensor<ui64> {
   func.return %0 : tensor<ui64>
 }
 
+// CHECK-LABEL: "type_f8E3M4"
+func.func @type_f8E3M4(%arg0: tensor<f8E3M4>, %arg1: tensor<f8E3M4>) -> tensor<f8E3M4> {
+  // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
+  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
+  func.return %0 : tensor<f8E3M4>
+}
+
+// CHECK-LABEL: "type_f8E4M3"
+func.func @type_f8E4M3(%arg0: tensor<f8E4M3>, %arg1: tensor<f8E4M3>) -> tensor<f8E4M3> {
+  // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
+  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
+  func.return %0 : tensor<f8E4M3>
+}
+
 // CHECK-LABEL: "type_f8E4M3FN"
 func.func @type_f8E4M3FN(%arg0: tensor<f8E4M3FN>, %arg1: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
   // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
```

tests/Dialect/mhlo/ops.mlir

Lines changed: 14 additions & 0 deletions
```diff
@@ -6832,6 +6832,20 @@ func.func @invalid_dimension_attr(%arg0: tensor<?x?xf32, #mhlo.type_extensions<b
 
 // -----
 
+func.func @f8e3m4(%arg0: tensor<f16>) -> tensor<f8E3M4> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E3M4>
+  func.return %0 : tensor<f8E3M4>
+}
+
+// -----
+
+func.func @f8e4m3(%arg0: tensor<f16>) -> tensor<f8E4M3> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E4M3>
+  func.return %0 : tensor<f8E4M3>
+}
+
+// -----
+
 func.func @f8e4m3fn(%arg0: tensor<f16>) -> tensor<f8E4M3FN> {
   %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E4M3FN>
   func.return %0 : tensor<f8E4M3FN>
```

tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir

Lines changed: 14 additions & 0 deletions
```diff
@@ -1787,6 +1787,20 @@ func.func @type_ui64(%arg0: tensor<ui64>, %arg1: tensor<ui64>) -> tensor<ui64> {
   func.return %0 : tensor<ui64>
 }
 
+// CHECK-LABEL: "type_f8E3M4"
+func.func @type_f8E3M4(%arg0: tensor<f8E3M4>, %arg1: tensor<f8E3M4>) -> tensor<f8E3M4> {
+  // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
+  func.return %0 : tensor<f8E3M4>
+}
+
+// CHECK-LABEL: "type_f8E4M3"
+func.func @type_f8E4M3(%arg0: tensor<f8E4M3>, %arg1: tensor<f8E4M3>) -> tensor<f8E4M3> {
+  // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
+  func.return %0 : tensor<f8E4M3>
+}
+
 // CHECK-LABEL: "type_f8E4M3FN"
 func.func @type_f8E4M3FN(%arg0: tensor<f8E4M3FN>, %arg1: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
   // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
```
