|
13 | 13 | #include "IRModule.h"
|
14 | 14 |
|
15 | 15 | #include "PybindUtils.h"
|
| 16 | +#include <pybind11/numpy.h> |
16 | 17 |
|
17 | 18 | #include "llvm/ADT/ScopeExit.h"
|
18 | 19 | #include "llvm/Support/raw_ostream.h"
|
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute
|
757 | 758 | throw py::error_already_set();
|
758 | 759 | }
|
759 | 760 | auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
|
760 |
| - SmallVector<int64_t> shape; |
761 |
| - if (explicitShape) { |
762 |
| - shape.append(explicitShape->begin(), explicitShape->end()); |
763 |
| - } else { |
764 |
| - shape.append(view.shape, view.shape + view.ndim); |
765 |
| - } |
766 | 761 |
|
767 |
| - MlirAttribute encodingAttr = mlirAttributeGetNull(); |
768 | 762 | MlirContext context = contextWrapper->get();
|
769 |
| - |
770 |
| - // Detect format codes that are suitable for bulk loading. This includes |
771 |
| - // all byte aligned integer and floating point types up to 8 bytes. |
772 |
| - // Notably, this excludes, bool (which needs to be bit-packed) and |
773 |
| - // other exotics which do not have a direct representation in the buffer |
774 |
| - // protocol (i.e. complex, etc). |
775 |
| - std::optional<MlirType> bulkLoadElementType; |
776 |
| - if (explicitType) { |
777 |
| - bulkLoadElementType = *explicitType; |
778 |
| - } else { |
779 |
| - std::string_view format(view.format); |
780 |
| - if (format == "f") { |
781 |
| - // f32 |
782 |
| - assert(view.itemsize == 4 && "mismatched array itemsize"); |
783 |
| - bulkLoadElementType = mlirF32TypeGet(context); |
784 |
| - } else if (format == "d") { |
785 |
| - // f64 |
786 |
| - assert(view.itemsize == 8 && "mismatched array itemsize"); |
787 |
| - bulkLoadElementType = mlirF64TypeGet(context); |
788 |
| - } else if (format == "e") { |
789 |
| - // f16 |
790 |
| - assert(view.itemsize == 2 && "mismatched array itemsize"); |
791 |
| - bulkLoadElementType = mlirF16TypeGet(context); |
792 |
| - } else if (isSignedIntegerFormat(format)) { |
793 |
| - if (view.itemsize == 4) { |
794 |
| - // i32 |
795 |
| - bulkLoadElementType = signless |
796 |
| - ? mlirIntegerTypeGet(context, 32) |
797 |
| - : mlirIntegerTypeSignedGet(context, 32); |
798 |
| - } else if (view.itemsize == 8) { |
799 |
| - // i64 |
800 |
| - bulkLoadElementType = signless |
801 |
| - ? mlirIntegerTypeGet(context, 64) |
802 |
| - : mlirIntegerTypeSignedGet(context, 64); |
803 |
| - } else if (view.itemsize == 1) { |
804 |
| - // i8 |
805 |
| - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
806 |
| - : mlirIntegerTypeSignedGet(context, 8); |
807 |
| - } else if (view.itemsize == 2) { |
808 |
| - // i16 |
809 |
| - bulkLoadElementType = signless |
810 |
| - ? mlirIntegerTypeGet(context, 16) |
811 |
| - : mlirIntegerTypeSignedGet(context, 16); |
812 |
| - } |
813 |
| - } else if (isUnsignedIntegerFormat(format)) { |
814 |
| - if (view.itemsize == 4) { |
815 |
| - // unsigned i32 |
816 |
| - bulkLoadElementType = signless |
817 |
| - ? mlirIntegerTypeGet(context, 32) |
818 |
| - : mlirIntegerTypeUnsignedGet(context, 32); |
819 |
| - } else if (view.itemsize == 8) { |
820 |
| - // unsigned i64 |
821 |
| - bulkLoadElementType = signless |
822 |
| - ? mlirIntegerTypeGet(context, 64) |
823 |
| - : mlirIntegerTypeUnsignedGet(context, 64); |
824 |
| - } else if (view.itemsize == 1) { |
825 |
| - // i8 |
826 |
| - bulkLoadElementType = signless |
827 |
| - ? mlirIntegerTypeGet(context, 8) |
828 |
| - : mlirIntegerTypeUnsignedGet(context, 8); |
829 |
| - } else if (view.itemsize == 2) { |
830 |
| - // i16 |
831 |
| - bulkLoadElementType = signless |
832 |
| - ? mlirIntegerTypeGet(context, 16) |
833 |
| - : mlirIntegerTypeUnsignedGet(context, 16); |
834 |
| - } |
835 |
| - } |
836 |
| - if (!bulkLoadElementType) { |
837 |
| - throw std::invalid_argument( |
838 |
| - std::string("unimplemented array format conversion from format: ") + |
839 |
| - std::string(format)); |
840 |
| - } |
841 |
| - } |
842 |
| - |
843 |
| - MlirType shapedType; |
844 |
| - if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
845 |
| - if (explicitShape) { |
846 |
| - throw std::invalid_argument("Shape can only be specified explicitly " |
847 |
| - "when the type is not a shaped type."); |
848 |
| - } |
849 |
| - shapedType = *bulkLoadElementType; |
850 |
| - } else { |
851 |
| - shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), |
852 |
| - *bulkLoadElementType, encodingAttr); |
853 |
| - } |
854 |
| - size_t rawBufferSize = view.len; |
855 |
| - MlirAttribute attr = |
856 |
| - mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); |
| 763 | + MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, |
| 764 | + explicitShape, context); |
857 | 765 | if (mlirAttributeIsNull(attr)) {
|
858 | 766 | throw std::invalid_argument(
|
859 | 767 | "DenseElementsAttr could not be constructed from the given buffer. "
|
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute
|
963 | 871 | // unsigned i16
|
964 | 872 | return bufferInfo<uint16_t>(shapedType);
|
965 | 873 | }
|
| 874 | + } else if (mlirTypeIsAInteger(elementType) && |
| 875 | + mlirIntegerTypeGetWidth(elementType) == 1) { |
| 876 | + // i1 / bool |
| 877 | + // We can not send the buffer directly back to Python, because the i1 |
| 878 | + // values are bitpacked within MLIR. We call numpy's unpackbits function |
| 879 | + // to convert the bytes. |
| 880 | + return getBooleanBufferFromBitpackedAttribute(); |
966 | 881 | }
|
967 | 882 |
|
968 | 883 | // TODO: Currently crashes the program.
|
@@ -1016,14 +931,183 @@ class PyDenseElementsAttribute
|
1016 | 931 | code == 'q';
|
1017 | 932 | }
|
1018 | 933 |
|
| 934 | + static MlirType |
| 935 | + getShapedType(std::optional<MlirType> bulkLoadElementType, |
| 936 | + std::optional<std::vector<int64_t>> explicitShape, |
| 937 | + Py_buffer &view) { |
| 938 | + SmallVector<int64_t> shape; |
| 939 | + if (explicitShape) { |
| 940 | + shape.append(explicitShape->begin(), explicitShape->end()); |
| 941 | + } else { |
| 942 | + shape.append(view.shape, view.shape + view.ndim); |
| 943 | + } |
| 944 | + |
| 945 | + if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
| 946 | + if (explicitShape) { |
| 947 | + throw std::invalid_argument("Shape can only be specified explicitly " |
| 948 | + "when the type is not a shaped type."); |
| 949 | + } |
| 950 | + return *bulkLoadElementType; |
| 951 | + } else { |
| 952 | + MlirAttribute encodingAttr = mlirAttributeGetNull(); |
| 953 | + return mlirRankedTensorTypeGet(shape.size(), shape.data(), |
| 954 | + *bulkLoadElementType, encodingAttr); |
| 955 | + } |
| 956 | + } |
| 957 | + |
| 958 | + static MlirAttribute getAttributeFromBuffer( |
| 959 | + Py_buffer &view, bool signless, std::optional<PyType> explicitType, |
| 960 | + std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { |
| 961 | + // Detect format codes that are suitable for bulk loading. This includes |
| 962 | + // all byte aligned integer and floating point types up to 8 bytes. |
| 963 | + // Notably, this excludes exotics types which do not have a direct |
| 964 | + // representation in the buffer protocol (i.e. complex, etc). |
| 965 | + std::optional<MlirType> bulkLoadElementType; |
| 966 | + if (explicitType) { |
| 967 | + bulkLoadElementType = *explicitType; |
| 968 | + } else { |
| 969 | + std::string_view format(view.format); |
| 970 | + if (format == "f") { |
| 971 | + // f32 |
| 972 | + assert(view.itemsize == 4 && "mismatched array itemsize"); |
| 973 | + bulkLoadElementType = mlirF32TypeGet(context); |
| 974 | + } else if (format == "d") { |
| 975 | + // f64 |
| 976 | + assert(view.itemsize == 8 && "mismatched array itemsize"); |
| 977 | + bulkLoadElementType = mlirF64TypeGet(context); |
| 978 | + } else if (format == "e") { |
| 979 | + // f16 |
| 980 | + assert(view.itemsize == 2 && "mismatched array itemsize"); |
| 981 | + bulkLoadElementType = mlirF16TypeGet(context); |
| 982 | + } else if (format == "?") { |
| 983 | + // i1 |
| 984 | + // The i1 type needs to be bit-packed, so we will handle it seperately |
| 985 | + return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, |
| 986 | + context); |
| 987 | + } else if (isSignedIntegerFormat(format)) { |
| 988 | + if (view.itemsize == 4) { |
| 989 | + // i32 |
| 990 | + bulkLoadElementType = signless |
| 991 | + ? mlirIntegerTypeGet(context, 32) |
| 992 | + : mlirIntegerTypeSignedGet(context, 32); |
| 993 | + } else if (view.itemsize == 8) { |
| 994 | + // i64 |
| 995 | + bulkLoadElementType = signless |
| 996 | + ? mlirIntegerTypeGet(context, 64) |
| 997 | + : mlirIntegerTypeSignedGet(context, 64); |
| 998 | + } else if (view.itemsize == 1) { |
| 999 | + // i8 |
| 1000 | + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
| 1001 | + : mlirIntegerTypeSignedGet(context, 8); |
| 1002 | + } else if (view.itemsize == 2) { |
| 1003 | + // i16 |
| 1004 | + bulkLoadElementType = signless |
| 1005 | + ? mlirIntegerTypeGet(context, 16) |
| 1006 | + : mlirIntegerTypeSignedGet(context, 16); |
| 1007 | + } |
| 1008 | + } else if (isUnsignedIntegerFormat(format)) { |
| 1009 | + if (view.itemsize == 4) { |
| 1010 | + // unsigned i32 |
| 1011 | + bulkLoadElementType = signless |
| 1012 | + ? mlirIntegerTypeGet(context, 32) |
| 1013 | + : mlirIntegerTypeUnsignedGet(context, 32); |
| 1014 | + } else if (view.itemsize == 8) { |
| 1015 | + // unsigned i64 |
| 1016 | + bulkLoadElementType = signless |
| 1017 | + ? mlirIntegerTypeGet(context, 64) |
| 1018 | + : mlirIntegerTypeUnsignedGet(context, 64); |
| 1019 | + } else if (view.itemsize == 1) { |
| 1020 | + // i8 |
| 1021 | + bulkLoadElementType = signless |
| 1022 | + ? mlirIntegerTypeGet(context, 8) |
| 1023 | + : mlirIntegerTypeUnsignedGet(context, 8); |
| 1024 | + } else if (view.itemsize == 2) { |
| 1025 | + // i16 |
| 1026 | + bulkLoadElementType = signless |
| 1027 | + ? mlirIntegerTypeGet(context, 16) |
| 1028 | + : mlirIntegerTypeUnsignedGet(context, 16); |
| 1029 | + } |
| 1030 | + } |
| 1031 | + if (!bulkLoadElementType) { |
| 1032 | + throw std::invalid_argument( |
| 1033 | + std::string("unimplemented array format conversion from format: ") + |
| 1034 | + std::string(format)); |
| 1035 | + } |
| 1036 | + } |
| 1037 | + |
| 1038 | + MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); |
| 1039 | + return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); |
| 1040 | + } |
| 1041 | + |
| 1042 | + // There is a complication for boolean numpy arrays, as numpy represents them |
| 1043 | + // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans |
| 1044 | + // per byte. |
| 1045 | + static MlirAttribute getBitpackedAttributeFromBooleanBuffer( |
| 1046 | + Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, |
| 1047 | + MlirContext &context) { |
| 1048 | + if (llvm::endianness::native != llvm::endianness::little) { |
| 1049 | + // Given we have no good way of testing the behavior on big-endian systems |
| 1050 | + // we will throw |
| 1051 | + throw py::type_error("Constructing a bit-packed MLIR attribute is " |
| 1052 | + "unsupported on big-endian systems"); |
| 1053 | + } |
| 1054 | + |
| 1055 | + py::array_t<uint8_t> unpackedArray(view.len, |
| 1056 | + static_cast<uint8_t *>(view.buf)); |
| 1057 | + |
| 1058 | + py::module numpy = py::module::import("numpy"); |
| 1059 | + py::object packbits_func = numpy.attr("packbits"); |
| 1060 | + py::object packed_booleans = |
| 1061 | + packbits_func(unpackedArray, "bitorder"_a = "little"); |
| 1062 | + py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request(); |
| 1063 | + |
| 1064 | + MlirType bitpackedType = |
| 1065 | + getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); |
| 1066 | + return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, |
| 1067 | + pythonBuffer.ptr); |
| 1068 | + } |
| 1069 | + |
| 1070 | + // This does the opposite transformation of |
| 1071 | + // `getBitpackedAttributeFromBooleanBuffer` |
| 1072 | + py::buffer_info getBooleanBufferFromBitpackedAttribute() { |
| 1073 | + if (llvm::endianness::native != llvm::endianness::little) { |
| 1074 | + // Given we have no good way of testing the behavior on big-endian systems |
| 1075 | + // we will throw |
| 1076 | + throw py::type_error("Constructing a numpy array from a MLIR attribute " |
| 1077 | + "is unsupported on big-endian systems"); |
| 1078 | + } |
| 1079 | + |
| 1080 | + int64_t numBooleans = mlirElementsAttrGetNumElements(*this); |
| 1081 | + int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); |
| 1082 | + uint8_t *bitpackedData = static_cast<uint8_t *>( |
| 1083 | + const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
| 1084 | + py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData); |
| 1085 | + |
| 1086 | + py::module numpy = py::module::import("numpy"); |
| 1087 | + py::object unpackbits_func = numpy.attr("unpackbits"); |
| 1088 | + py::object unpacked_booleans = |
| 1089 | + unpackbits_func(packedArray, "bitorder"_a = "little"); |
| 1090 | + py::buffer_info pythonBuffer = |
| 1091 | + unpacked_booleans.cast<py::buffer>().request(); |
| 1092 | + |
| 1093 | + MlirType shapedType = mlirAttributeGetType(*this); |
| 1094 | + return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?"); |
| 1095 | + } |
| 1096 | + |
1019 | 1097 | template <typename Type>
|
1020 | 1098 | py::buffer_info bufferInfo(MlirType shapedType,
|
1021 | 1099 | const char *explicitFormat = nullptr) {
|
1022 |
| - intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1023 | 1100 | // Prepare the data for the buffer_info.
|
1024 |
| - // Buffer is configured for read-only access below. |
| 1101 | + // Buffer is configured for read-only access inside the `bufferInfo` call. |
1025 | 1102 | Type *data = static_cast<Type *>(
|
1026 | 1103 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
|
| 1104 | + return bufferInfo<Type>(shapedType, data, explicitFormat); |
| 1105 | + } |
| 1106 | + |
| 1107 | + template <typename Type> |
| 1108 | + py::buffer_info bufferInfo(MlirType shapedType, Type *data, |
| 1109 | + const char *explicitFormat = nullptr) { |
| 1110 | + intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1027 | 1111 | // Prepare the shape for the buffer_info.
|
1028 | 1112 | SmallVector<intptr_t, 4> shape;
|
1029 | 1113 | for (intptr_t i = 0; i < rank; ++i)
|
|
0 commit comments