Skip to content

Commit daaa3af

Browse files
committed
Add F4E2M1FN type: python interface
1 parent c479f09 commit daaa3af

File tree

17 files changed

+79
-31
lines changed

17 files changed

+79
-31
lines changed

xla/pjrt/c/pjrt_c_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,9 @@ typedef enum {
649649
// More truncated 8 bit floating-point formats.
650650
PJRT_Buffer_Type_F8E4M3,
651651
PJRT_Buffer_Type_F8E3M4,
652+
653+
// 4-bit MX floating-point format.
654+
PJRT_Buffer_Type_F4E2M1FN,
652655
} PJRT_Buffer_Type;
653656

654657
typedef enum {

xla/pjrt/c/pjrt_c_api_helpers.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) {
294294
return PJRT_Buffer_Type::PJRT_Buffer_Type_BF16;
295295
case xla::PrimitiveType::F64:
296296
return PJRT_Buffer_Type::PJRT_Buffer_Type_F64;
297+
case xla::PrimitiveType::F4E2M1FN:
298+
return PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN;
297299
case xla::PrimitiveType::F8E5M2:
298300
return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2;
299301
case xla::PrimitiveType::F8E4M3:
@@ -361,6 +363,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) {
361363
return xla::PrimitiveType::C64;
362364
case PJRT_Buffer_Type::PJRT_Buffer_Type_C128:
363365
return xla::PrimitiveType::C128;
366+
case PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN:
367+
return xla::PrimitiveType::F4E2M1FN;
364368
case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2:
365369
return xla::PrimitiveType::F8E5M2;
366370
case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3:

xla/python/ifrt/dtype.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ std::optional<int> DType::byte_size() const {
3232
case kU2:
3333
case kS4:
3434
case kU4:
35+
case kF4E2M1FN:
3536
// Smaller than a byte.
3637
return std::nullopt;
3738
case kPred:
@@ -77,6 +78,7 @@ std::optional<int> DType::bit_size() const {
7778
return 2;
7879
case kS4:
7980
case kU4:
81+
case kF4E2M1FN:
8082
return 4;
8183
case kPred:
8284
case kS8:
@@ -142,6 +144,7 @@ absl::StatusOr<DType> DType::FromProto(const DTypeProto& dtype_proto) {
142144
CASE(C64);
143145
CASE(C128);
144146
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
147+
// CASE(F4E2M1FN);
145148
// CASE(F8E3M4);
146149
// CASE(F8E4M3);
147150
CASE(F8E4M3FN);
@@ -190,6 +193,7 @@ DTypeProto DType::ToProto() const {
190193
CASE(C64);
191194
CASE(C128);
192195
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
196+
// CASE(F4E2M1FN);
193197
// CASE(F8E3M4);
194198
// CASE(F8E4M3);
195199
CASE(F8E4M3FN);

xla/python/ifrt/dtype.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ class DType {
8989
kF8E5M2 = 19,
9090
kF8E5M2FNUZ = 24,
9191

92-
// Next = 30
92+
// MX floating point types.
93+
kF4E2M1FN = 30,
94+
95+
// Next = 31
9396

9497
// Variable-length string represented as raw bytes, as in `bytes` in Python,
9598
// i.e., no encoding enforcement. String is not supported in XLA. DType.Kind

xla/python/ifrt/dtype.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,16 @@ message DTypeProto {
7171
KIND_F8E5M2 = 19;
7272
KIND_F8E5M2FNUZ = 24;
7373

74+
// MX floating point types.
75+
KIND_F4E2M1FN = 30;
76+
7477
// Variable-length string represented as raw bytes, as in `bytes` in Python,
7578
// i.e., no encoding enforcement. String is not supported in XLA. DType.Kind
7679
// needs to match xla.PrimitiveType enum, so choose a large enum to avoid
7780
// collision.
7881
KIND_STRING = 99;
82+
83+
// Next: 31
7984
}
8085
// LINT.ThenChange()
8186
Kind kind = 1;

xla/python/ifrt/dtype_test.cc

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,21 @@ TEST(DTypeTest, FromToFromProto) {
4242
TEST(DTypeTest, ByteSize) {
4343
for (const auto& [kind, byte_size] :
4444
std::vector<std::tuple<DType::Kind, int>>({
45-
{DType::kS2, -1},
46-
{DType::kU2, -1},
47-
{DType::kS4, -1},
48-
{DType::kU4, -1},
49-
{DType::kPred, 1},
50-
{DType::kS8, 1},
51-
{DType::kU8, 1},
52-
{DType::kF8E3M4, 1},
53-
{DType::kF8E4M3, 1},
54-
{DType::kF8E4M3FN, 1},
55-
{DType::kF8E4M3B11FNUZ, 1},
56-
{DType::kF8E4M3FNUZ, 1},
57-
{DType::kF8E5M2, 1},
58-
{DType::kF8E5M2FNUZ, 1},
59-
{DType::kS16, 2},
60-
{DType::kU16, 2},
61-
{DType::kF16, 2},
62-
{DType::kBF16, 2},
63-
{DType::kS32, 4},
64-
{DType::kU32, 4},
65-
{DType::kF32, 4},
66-
{DType::kS64, 8},
67-
{DType::kU64, 8},
68-
{DType::kF64, 8},
69-
{DType::kC64, 8},
70-
{DType::kC128, 16},
71-
{DType::kToken, -1},
72-
{DType::kInvalid, -1},
73-
{DType::kString, -1},
45+
{DType::kS2, -1}, {DType::kU2, -1},
46+
{DType::kS4, -1}, {DType::kU4, -1},
47+
{DType::kPred, 1}, {DType::kS8, 1},
48+
{DType::kU8, 1}, {DType::kF4E2M1FN, -1},
49+
{DType::kF8E3M4, 1}, {DType::kF8E4M3, 1},
50+
{DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1},
51+
{DType::kF8E4M3FNUZ, 1}, {DType::kF8E5M2, 1},
52+
{DType::kF8E5M2FNUZ, 1}, {DType::kS16, 2},
53+
{DType::kU16, 2}, {DType::kF16, 2},
54+
{DType::kBF16, 2}, {DType::kS32, 4},
55+
{DType::kU32, 4}, {DType::kF32, 4},
56+
{DType::kS64, 8}, {DType::kU64, 8},
57+
{DType::kF64, 8}, {DType::kC64, 8},
58+
{DType::kC128, 16}, {DType::kToken, -1},
59+
{DType::kInvalid, -1}, {DType::kString, -1},
7460
})) {
7561
EXPECT_EQ(DType(kind).byte_size(),
7662
byte_size == -1 ? std::nullopt : std::make_optional(byte_size));
@@ -87,6 +73,7 @@ TEST(DTypeTest, BitSize) {
8773
{DType::kPred, 8},
8874
{DType::kS8, 8},
8975
{DType::kU8, 8},
76+
{DType::kF4E2M1FN, 4},
9077
{DType::kF8E3M4, 8},
9178
{DType::kF8E4M3, 8},
9279
{DType::kF8E4M3FN, 8},

xla/python/pjrt_ifrt/pjrt_dtype.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ absl::StatusOr<xla::PrimitiveType> ToPrimitiveType(DType dtype) {
4444
CASE(DType::kU16, xla::PrimitiveType::U16);
4545
CASE(DType::kU32, xla::PrimitiveType::U32);
4646
CASE(DType::kU64, xla::PrimitiveType::U64);
47+
CASE(DType::kF4E2M1FN, xla::PrimitiveType::F4E2M1FN);
4748
CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4);
4849
CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3);
4950
CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN);
@@ -83,6 +84,7 @@ absl::StatusOr<DType> ToDType(xla::PrimitiveType primitive_type) {
8384
case xla::PrimitiveType::U16:
8485
case xla::PrimitiveType::U32:
8586
case xla::PrimitiveType::U64:
87+
case xla::PrimitiveType::F4E2M1FN:
8688
case xla::PrimitiveType::F8E3M4:
8789
case xla::PrimitiveType::F8E4M3:
8890
case xla::PrimitiveType::F8E4M3FN:

xla/python/py_values.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
184184
} else if (std::is_same<T, bfloat16>()) {
185185
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
186186
type = BF16;
187+
} else if (std::is_same<T, tsl::float4_e2m1fn>()) {
188+
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
189+
type = F4E2M1FN;
187190
} else if (std::is_same<T, tsl::float8_e3m4>()) {
188191
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
189192
type = F8E3M4;
@@ -398,6 +401,10 @@ absl::StatusOr<DevicePutResultFn> DevicePut(nb::handle arg,
398401
(*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
399402
(*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
400403
(*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
404+
if (dtypes.np_float4_e2m1fn.has_value()) {
405+
(*p)[dtypes.np_float4_e2m1fn->ptr()] =
406+
HandleNumpyScalar<tsl::float4_e2m1fn>;
407+
}
401408
if (dtypes.np_float8_e3m4.has_value()) {
402409
(*p)[dtypes.np_float8_e3m4->ptr()] =
403410
HandleNumpyScalar<tsl::float8_e3m4>;
@@ -595,6 +602,7 @@ absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nb::handle arg,
595602
(*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
596603
(*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
597604
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
605+
// (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler;
598606
// (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler;
599607
// (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler;
600608
(*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler;

xla/python/types.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ namespace {
5858

5959
struct CustomDtypes {
6060
nb_dtype bfloat16;
61+
std::optional<nb_dtype> float4_e2m1fn;
6162
std::optional<nb_dtype> float8_e3m4;
6263
std::optional<nb_dtype> float8_e4m3;
6364
nb_dtype float8_e4m3fn;
@@ -76,6 +77,10 @@ const CustomDtypes& GetCustomDtypes() {
7677
nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes");
7778
auto* dtypes = new CustomDtypes;
7879
dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16"));
80+
if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) {
81+
dtypes->float4_e2m1fn =
82+
nb_dtype::from_args(ml_dtypes.attr("float4_e2m1fn"));
83+
}
7984
if (nb::hasattr(ml_dtypes, "float8_e3m4")) {
8085
dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4"));
8186
}
@@ -147,6 +152,9 @@ absl::StatusOr<PrimitiveType> DtypeToPrimitiveType(const nb_dtype& np_type) {
147152
auto* map =
148153
new absl::flat_hash_map<nb_dtype, PrimitiveType, DtypeHash, DtypeEq>();
149154
map->emplace(custom_dtypes.bfloat16, BF16);
155+
if (custom_dtypes.float4_e2m1fn.has_value()) {
156+
map->emplace(*custom_dtypes.float4_e2m1fn, F4E2M1FN);
157+
}
150158
if (custom_dtypes.float8_e3m4.has_value()) {
151159
map->emplace(*custom_dtypes.float8_e3m4, F8E3M4);
152160
}
@@ -217,6 +225,11 @@ absl::StatusOr<nb_dtype> PrimitiveTypeToNbDtype(PrimitiveType type) {
217225
return to_nb_dtype(NPY_UINT32);
218226
case U64:
219227
return to_nb_dtype(NPY_UINT64);
228+
case F4E2M1FN:
229+
if (custom_dtypes.float4_e2m1fn.has_value()) {
230+
return *custom_dtypes.float4_e2m1fn;
231+
}
232+
break;
220233
case F8E3M4:
221234
if (custom_dtypes.float8_e3m4.has_value()) {
222235
return *custom_dtypes.float8_e3m4;
@@ -307,6 +320,11 @@ absl::StatusOr<nb_dtype> IfrtDtypeToNbDtype(ifrt::DType dtype) {
307320
return to_nb_dtype(NPY_COMPLEX64);
308321
case ifrt::DType::kC128:
309322
return to_nb_dtype(NPY_COMPLEX128);
323+
case ifrt::DType::kF4E2M1FN:
324+
if (custom_dtypes.float4_e2m1fn.has_value()) {
325+
return *custom_dtypes.float4_e2m1fn;
326+
}
327+
break;
310328
case ifrt::DType::kF8E3M4:
311329
if (custom_dtypes.float8_e3m4.has_value()) {
312330
return *custom_dtypes.float8_e3m4;
@@ -380,6 +398,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() {
380398
dtypes->np_uint32 = nb::object(numpy.attr("uint32"));
381399
dtypes->np_uint64 = nb::object(numpy.attr("uint64"));
382400
dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16"));
401+
if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) {
402+
dtypes->np_float4_e2m1fn = nb::object(ml_dtypes.attr("float4_e2m1fn"));
403+
}
383404
if (nb::hasattr(ml_dtypes, "float8_e3m4")) {
384405
dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4"));
385406
}

xla/python/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ struct NumpyScalarTypes {
8181
nanobind::object np_uint64;
8282
nanobind::object np_bfloat16;
8383
// Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0.
84+
std::optional<nanobind::object> np_float4_e2m1fn;
8485
std::optional<nanobind::object> np_float8_e3m4;
8586
std::optional<nanobind::object> np_float8_e4m3;
8687
nanobind::object np_float8_e4m3fn;

xla/python/xla.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ NB_MODULE(xla_extension, m) {
205205
.value("U64", U64)
206206
.value("F16", F16)
207207
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
208+
// .value("F4E2M1FN", F4E2M1FN)
208209
// .value("F8E3M4", F8E3M4)
209210
// .value("F8E4M3", F8E4M3)
210211
.value("F8E4M3FN", F8E4M3FN)

xla/python/xla_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
280280

281281
bfloat16 = ml_dtypes.bfloat16
282282
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
283+
# float4_e2m1fn = ml_dtypes.float4_e2m1fn
283284
# float8_e3m4 = ml_dtypes.float8_e3m4
284285
# float8_e4m3 = ml_dtypes.float8_e4m3
285286
float8_e4m3fn = ml_dtypes.float8_e4m3fn
@@ -301,6 +302,7 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
301302
PrimitiveType.U32: np.dtype('uint32'),
302303
PrimitiveType.U64: np.dtype('uint64'),
303304
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
305+
# PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn),
304306
# PrimitiveType.F8E3M4: np.dtype(float8_e3m4),
305307
# PrimitiveType.F8E4M3: np.dtype(float8_e4m3),
306308
PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn),

xla/python/xla_client.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ mlir_api_version: int
6262

6363
bfloat16: type[numpy.generic]
6464
# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
65+
# float4_e2m1fn: type[numpy.generic]
6566
# float8_e3m4: type[numpy.generic]
6667
# float8_e4m3: type[numpy.generic]
6768
float8_e4m3fn: type[numpy.generic]

xla/python/xla_client_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555

5656
bfloat16 = xla_client.bfloat16
5757
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
58+
# float4_e2m1fn = xla_client.float4_e2m1fn
5859
# float8_e3m4 = xla_client.float8_e3m4
5960
# float8_e4m3 = xla_client.float8_e4m3
6061
float8_e4m3fn = xla_client.float8_e4m3fn
@@ -189,7 +190,7 @@ def TestFactory(xla_backend,
189190
fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2]
190191
standard_dtypes += fp8_dtypes
191192
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
192-
# standard_dtypes += [float8_e3m4, float8_e4m3]
193+
# standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3]
193194
dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes
194195

195196
class ComputationTest(parameterized.TestCase):

xla/python/xla_extension/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class PrimitiveType(enum.IntEnum):
7474
U16: PrimitiveType
7575
U32: PrimitiveType
7676
U64: PrimitiveType
77+
F4E2M1FN: PrimitiveType
7778
F8E3M4: PrimitiveType
7879
F8E4M3: PrimitiveType
7980
F8E4M3FN: PrimitiveType

xla/tsl/python/lib/core/ml_dtypes.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ struct MlDtypesInitInfo {
6161

6262
numpy_dtypes.bfloat16 =
6363
py::dtype::from_args(ml_dtypes.attr("bfloat16")).num();
64+
numpy_dtypes.float4_e2m1fn =
65+
py::dtype::from_args(ml_dtypes.attr("float4_e2m1fn")).num();
6466
numpy_dtypes.float8_e3m4 =
6567
py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num();
6668
numpy_dtypes.float8_e4m3 =
@@ -85,6 +87,7 @@ struct MlDtypesInitInfo {
8587

8688
// Verify all types were successfully loaded.
8789
if (numpy_dtypes.bfloat16 == NPY_NOTYPE ||
90+
numpy_dtypes.float4_e2m1fn == NPY_NOTYPE ||
8891
numpy_dtypes.float8_e3m4 == NPY_NOTYPE ||
8992
numpy_dtypes.float8_e4m3 == NPY_NOTYPE ||
9093
numpy_dtypes.float8_e4m3fn == NPY_NOTYPE ||

xla/tsl/python/lib/core/ml_dtypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace ml_dtypes {
2424

2525
struct NumpyDtypes {
2626
int bfloat16;
27+
int float4_e2m1fn;
2728
int float8_e3m4;
2829
int float8_e4m3;
2930
int float8_e4m3fn;

0 commit comments

Comments
 (0)