Skip to content

Commit fb7bf7a

Browse files
authored
[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes (#113064)
Currently it is unsupported to: 1. Convert a `MlirAttribute` with type `i1` to a numpy array 2. Convert a boolean numpy array to a `MlirAttribute` Currently the entire Python application violently crashes with a quite poor error message pybind/pybind11#3336 The complication handling these conversions, is that `MlirAttribute` represent booleans as a bit-packed `i1` type, whereas numpy represents booleans as a byte array with 8 bit used per boolean. This PR proposes the following approach: 1. When converting a `i1` typed `MlirAttribute` to a numpy array, we can not directly use the underlying raw data backing the `MlirAttribute` as a buffer to Python, as done for other types. Instead, a copy of the data is generated using numpy's unpackbits function, and the result is send back to Python. 2. When constructing a `MlirAttribute` from a numpy array, first the python data is read as a `uint8_t` to get it converted to the endianess used internally in mlir. Then the booleans are bitpacked using numpy's bitpack function, and the bitpacked array is saved as the `MlirAttribute` representation. Please note that I am not sure if this approach is the desired solution. I'd appreciate any feedback.
1 parent 78bfcc5 commit fb7bf7a

File tree

2 files changed

+253
-97
lines changed

2 files changed

+253
-97
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 181 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "IRModule.h"
1414

1515
#include "PybindUtils.h"
16+
#include <pybind11/numpy.h>
1617

1718
#include "llvm/ADT/ScopeExit.h"
1819
#include "llvm/Support/raw_ostream.h"
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute
757758
throw py::error_already_set();
758759
}
759760
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
760-
SmallVector<int64_t> shape;
761-
if (explicitShape) {
762-
shape.append(explicitShape->begin(), explicitShape->end());
763-
} else {
764-
shape.append(view.shape, view.shape + view.ndim);
765-
}
766761

767-
MlirAttribute encodingAttr = mlirAttributeGetNull();
768762
MlirContext context = contextWrapper->get();
769-
770-
// Detect format codes that are suitable for bulk loading. This includes
771-
// all byte aligned integer and floating point types up to 8 bytes.
772-
// Notably, this excludes, bool (which needs to be bit-packed) and
773-
// other exotics which do not have a direct representation in the buffer
774-
// protocol (i.e. complex, etc).
775-
std::optional<MlirType> bulkLoadElementType;
776-
if (explicitType) {
777-
bulkLoadElementType = *explicitType;
778-
} else {
779-
std::string_view format(view.format);
780-
if (format == "f") {
781-
// f32
782-
assert(view.itemsize == 4 && "mismatched array itemsize");
783-
bulkLoadElementType = mlirF32TypeGet(context);
784-
} else if (format == "d") {
785-
// f64
786-
assert(view.itemsize == 8 && "mismatched array itemsize");
787-
bulkLoadElementType = mlirF64TypeGet(context);
788-
} else if (format == "e") {
789-
// f16
790-
assert(view.itemsize == 2 && "mismatched array itemsize");
791-
bulkLoadElementType = mlirF16TypeGet(context);
792-
} else if (isSignedIntegerFormat(format)) {
793-
if (view.itemsize == 4) {
794-
// i32
795-
bulkLoadElementType = signless
796-
? mlirIntegerTypeGet(context, 32)
797-
: mlirIntegerTypeSignedGet(context, 32);
798-
} else if (view.itemsize == 8) {
799-
// i64
800-
bulkLoadElementType = signless
801-
? mlirIntegerTypeGet(context, 64)
802-
: mlirIntegerTypeSignedGet(context, 64);
803-
} else if (view.itemsize == 1) {
804-
// i8
805-
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
806-
: mlirIntegerTypeSignedGet(context, 8);
807-
} else if (view.itemsize == 2) {
808-
// i16
809-
bulkLoadElementType = signless
810-
? mlirIntegerTypeGet(context, 16)
811-
: mlirIntegerTypeSignedGet(context, 16);
812-
}
813-
} else if (isUnsignedIntegerFormat(format)) {
814-
if (view.itemsize == 4) {
815-
// unsigned i32
816-
bulkLoadElementType = signless
817-
? mlirIntegerTypeGet(context, 32)
818-
: mlirIntegerTypeUnsignedGet(context, 32);
819-
} else if (view.itemsize == 8) {
820-
// unsigned i64
821-
bulkLoadElementType = signless
822-
? mlirIntegerTypeGet(context, 64)
823-
: mlirIntegerTypeUnsignedGet(context, 64);
824-
} else if (view.itemsize == 1) {
825-
// i8
826-
bulkLoadElementType = signless
827-
? mlirIntegerTypeGet(context, 8)
828-
: mlirIntegerTypeUnsignedGet(context, 8);
829-
} else if (view.itemsize == 2) {
830-
// i16
831-
bulkLoadElementType = signless
832-
? mlirIntegerTypeGet(context, 16)
833-
: mlirIntegerTypeUnsignedGet(context, 16);
834-
}
835-
}
836-
if (!bulkLoadElementType) {
837-
throw std::invalid_argument(
838-
std::string("unimplemented array format conversion from format: ") +
839-
std::string(format));
840-
}
841-
}
842-
843-
MlirType shapedType;
844-
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
845-
if (explicitShape) {
846-
throw std::invalid_argument("Shape can only be specified explicitly "
847-
"when the type is not a shaped type.");
848-
}
849-
shapedType = *bulkLoadElementType;
850-
} else {
851-
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
852-
*bulkLoadElementType, encodingAttr);
853-
}
854-
size_t rawBufferSize = view.len;
855-
MlirAttribute attr =
856-
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
763+
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
764+
explicitShape, context);
857765
if (mlirAttributeIsNull(attr)) {
858766
throw std::invalid_argument(
859767
"DenseElementsAttr could not be constructed from the given buffer. "
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute
963871
// unsigned i16
964872
return bufferInfo<uint16_t>(shapedType);
965873
}
874+
} else if (mlirTypeIsAInteger(elementType) &&
875+
mlirIntegerTypeGetWidth(elementType) == 1) {
876+
// i1 / bool
877+
// We can not send the buffer directly back to Python, because the i1
878+
// values are bitpacked within MLIR. We call numpy's unpackbits function
879+
// to convert the bytes.
880+
return getBooleanBufferFromBitpackedAttribute();
966881
}
967882

968883
// TODO: Currently crashes the program.
@@ -1016,14 +931,183 @@ class PyDenseElementsAttribute
1016931
code == 'q';
1017932
}
1018933

934+
static MlirType
935+
getShapedType(std::optional<MlirType> bulkLoadElementType,
936+
std::optional<std::vector<int64_t>> explicitShape,
937+
Py_buffer &view) {
938+
SmallVector<int64_t> shape;
939+
if (explicitShape) {
940+
shape.append(explicitShape->begin(), explicitShape->end());
941+
} else {
942+
shape.append(view.shape, view.shape + view.ndim);
943+
}
944+
945+
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
946+
if (explicitShape) {
947+
throw std::invalid_argument("Shape can only be specified explicitly "
948+
"when the type is not a shaped type.");
949+
}
950+
return *bulkLoadElementType;
951+
} else {
952+
MlirAttribute encodingAttr = mlirAttributeGetNull();
953+
return mlirRankedTensorTypeGet(shape.size(), shape.data(),
954+
*bulkLoadElementType, encodingAttr);
955+
}
956+
}
957+
958+
static MlirAttribute getAttributeFromBuffer(
959+
Py_buffer &view, bool signless, std::optional<PyType> explicitType,
960+
std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
961+
// Detect format codes that are suitable for bulk loading. This includes
962+
// all byte aligned integer and floating point types up to 8 bytes.
963+
// Notably, this excludes exotics types which do not have a direct
964+
// representation in the buffer protocol (i.e. complex, etc).
965+
std::optional<MlirType> bulkLoadElementType;
966+
if (explicitType) {
967+
bulkLoadElementType = *explicitType;
968+
} else {
969+
std::string_view format(view.format);
970+
if (format == "f") {
971+
// f32
972+
assert(view.itemsize == 4 && "mismatched array itemsize");
973+
bulkLoadElementType = mlirF32TypeGet(context);
974+
} else if (format == "d") {
975+
// f64
976+
assert(view.itemsize == 8 && "mismatched array itemsize");
977+
bulkLoadElementType = mlirF64TypeGet(context);
978+
} else if (format == "e") {
979+
// f16
980+
assert(view.itemsize == 2 && "mismatched array itemsize");
981+
bulkLoadElementType = mlirF16TypeGet(context);
982+
} else if (format == "?") {
983+
// i1
984+
// The i1 type needs to be bit-packed, so we will handle it seperately
985+
return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
986+
context);
987+
} else if (isSignedIntegerFormat(format)) {
988+
if (view.itemsize == 4) {
989+
// i32
990+
bulkLoadElementType = signless
991+
? mlirIntegerTypeGet(context, 32)
992+
: mlirIntegerTypeSignedGet(context, 32);
993+
} else if (view.itemsize == 8) {
994+
// i64
995+
bulkLoadElementType = signless
996+
? mlirIntegerTypeGet(context, 64)
997+
: mlirIntegerTypeSignedGet(context, 64);
998+
} else if (view.itemsize == 1) {
999+
// i8
1000+
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1001+
: mlirIntegerTypeSignedGet(context, 8);
1002+
} else if (view.itemsize == 2) {
1003+
// i16
1004+
bulkLoadElementType = signless
1005+
? mlirIntegerTypeGet(context, 16)
1006+
: mlirIntegerTypeSignedGet(context, 16);
1007+
}
1008+
} else if (isUnsignedIntegerFormat(format)) {
1009+
if (view.itemsize == 4) {
1010+
// unsigned i32
1011+
bulkLoadElementType = signless
1012+
? mlirIntegerTypeGet(context, 32)
1013+
: mlirIntegerTypeUnsignedGet(context, 32);
1014+
} else if (view.itemsize == 8) {
1015+
// unsigned i64
1016+
bulkLoadElementType = signless
1017+
? mlirIntegerTypeGet(context, 64)
1018+
: mlirIntegerTypeUnsignedGet(context, 64);
1019+
} else if (view.itemsize == 1) {
1020+
// i8
1021+
bulkLoadElementType = signless
1022+
? mlirIntegerTypeGet(context, 8)
1023+
: mlirIntegerTypeUnsignedGet(context, 8);
1024+
} else if (view.itemsize == 2) {
1025+
// i16
1026+
bulkLoadElementType = signless
1027+
? mlirIntegerTypeGet(context, 16)
1028+
: mlirIntegerTypeUnsignedGet(context, 16);
1029+
}
1030+
}
1031+
if (!bulkLoadElementType) {
1032+
throw std::invalid_argument(
1033+
std::string("unimplemented array format conversion from format: ") +
1034+
std::string(format));
1035+
}
1036+
}
1037+
1038+
MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1039+
return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1040+
}
1041+
1042+
// There is a complication for boolean numpy arrays, as numpy represents them
1043+
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
1044+
// per byte.
1045+
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1046+
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1047+
MlirContext &context) {
1048+
if (llvm::endianness::native != llvm::endianness::little) {
1049+
// Given we have no good way of testing the behavior on big-endian systems
1050+
// we will throw
1051+
throw py::type_error("Constructing a bit-packed MLIR attribute is "
1052+
"unsupported on big-endian systems");
1053+
}
1054+
1055+
py::array_t<uint8_t> unpackedArray(view.len,
1056+
static_cast<uint8_t *>(view.buf));
1057+
1058+
py::module numpy = py::module::import("numpy");
1059+
py::object packbits_func = numpy.attr("packbits");
1060+
py::object packed_booleans =
1061+
packbits_func(unpackedArray, "bitorder"_a = "little");
1062+
py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
1063+
1064+
MlirType bitpackedType =
1065+
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1066+
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1067+
pythonBuffer.ptr);
1068+
}
1069+
1070+
// This does the opposite transformation of
1071+
// `getBitpackedAttributeFromBooleanBuffer`
1072+
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
1073+
if (llvm::endianness::native != llvm::endianness::little) {
1074+
// Given we have no good way of testing the behavior on big-endian systems
1075+
// we will throw
1076+
throw py::type_error("Constructing a numpy array from a MLIR attribute "
1077+
"is unsupported on big-endian systems");
1078+
}
1079+
1080+
int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1081+
int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1082+
uint8_t *bitpackedData = static_cast<uint8_t *>(
1083+
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1084+
py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
1085+
1086+
py::module numpy = py::module::import("numpy");
1087+
py::object unpackbits_func = numpy.attr("unpackbits");
1088+
py::object unpacked_booleans =
1089+
unpackbits_func(packedArray, "bitorder"_a = "little");
1090+
py::buffer_info pythonBuffer =
1091+
unpacked_booleans.cast<py::buffer>().request();
1092+
1093+
MlirType shapedType = mlirAttributeGetType(*this);
1094+
return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
1095+
}
1096+
10191097
template <typename Type>
10201098
py::buffer_info bufferInfo(MlirType shapedType,
10211099
const char *explicitFormat = nullptr) {
1022-
intptr_t rank = mlirShapedTypeGetRank(shapedType);
10231100
// Prepare the data for the buffer_info.
1024-
// Buffer is configured for read-only access below.
1101+
// Buffer is configured for read-only access inside the `bufferInfo` call.
10251102
Type *data = static_cast<Type *>(
10261103
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1104+
return bufferInfo<Type>(shapedType, data, explicitFormat);
1105+
}
1106+
1107+
template <typename Type>
1108+
py::buffer_info bufferInfo(MlirType shapedType, Type *data,
1109+
const char *explicitFormat = nullptr) {
1110+
intptr_t rank = mlirShapedTypeGetRank(shapedType);
10271111
// Prepare the shape for the buffer_info.
10281112
SmallVector<intptr_t, 4> shape;
10291113
for (intptr_t i = 0; i < rank; ++i)

mlir/test/python/ir/array_attributes.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,78 @@ def testGetDenseElementsF64():
326326
print(np.array(attr))
327327

328328

329+
### 1 bit/boolean integer arrays
330+
# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
331+
@run
332+
def testGetDenseElementsI1Signless():
333+
with Context():
334+
array = np.array([True], dtype=np.bool_)
335+
attr = DenseElementsAttr.get(array)
336+
# CHECK: dense<true> : tensor<1xi1>
337+
print(attr)
338+
# CHECK{LITERAL}: [ True]
339+
print(np.array(attr))
340+
341+
array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
342+
attr = DenseElementsAttr.get(array)
343+
# CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
344+
print(attr)
345+
# CHECK{LITERAL}: [[ True False True]
346+
# CHECK{LITERAL}: [ True True False]]
347+
print(np.array(attr))
348+
349+
array = np.array(
350+
[[True, True, False, False], [True, False, True, False]], dtype=np.bool_
351+
)
352+
attr = DenseElementsAttr.get(array)
353+
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
354+
print(attr)
355+
# CHECK{LITERAL}: [[ True True False False]
356+
# CHECK{LITERAL}: [ True False True False]]
357+
print(np.array(attr))
358+
359+
array = np.array(
360+
[
361+
[True, True, False, False],
362+
[True, False, True, False],
363+
[False, False, False, False],
364+
[True, True, True, True],
365+
[True, False, False, True],
366+
],
367+
dtype=np.bool_,
368+
)
369+
attr = DenseElementsAttr.get(array)
370+
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
371+
print(attr)
372+
# CHECK{LITERAL}: [[ True True False False]
373+
# CHECK{LITERAL}: [ True False True False]
374+
# CHECK{LITERAL}: [False False False False]
375+
# CHECK{LITERAL}: [ True True True True]
376+
# CHECK{LITERAL}: [ True False False True]]
377+
print(np.array(attr))
378+
379+
array = np.array(
380+
[
381+
[True, True, False, False, True, True, False, False, False],
382+
[False, False, False, True, False, True, True, False, True],
383+
],
384+
dtype=np.bool_,
385+
)
386+
attr = DenseElementsAttr.get(array)
387+
# CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
388+
print(attr)
389+
# CHECK{LITERAL}: [[ True True False False True True False False False]
390+
# CHECK{LITERAL}: [False False False True False True True False True]]
391+
print(np.array(attr))
392+
393+
array = np.array([], dtype=np.bool_)
394+
attr = DenseElementsAttr.get(array)
395+
# CHECK: dense<> : tensor<0xi1>
396+
print(attr)
397+
# CHECK{LITERAL}: []
398+
print(np.array(attr))
399+
400+
329401
### 16 bit integer arrays
330402
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
331403
@run

0 commit comments

Comments
 (0)