@@ -80,6 +80,8 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
80
80
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
81
81
static constexpr const char *pyClassName = " AffineMapAttr" ;
82
82
using PyConcreteAttribute::PyConcreteAttribute;
83
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
84
+ mlirAffineMapAttrGetTypeID;
83
85
84
86
static void bindDerived (ClassTy &c) {
85
87
c.def_static (
@@ -259,6 +261,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
259
261
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
260
262
static constexpr const char *pyClassName = " ArrayAttr" ;
261
263
using PyConcreteAttribute::PyConcreteAttribute;
264
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
265
+ mlirArrayAttrGetTypeID;
262
266
263
267
class PyArrayAttributeIterator {
264
268
public:
@@ -339,6 +343,8 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
339
343
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
340
344
static constexpr const char *pyClassName = " FloatAttr" ;
341
345
using PyConcreteAttribute::PyConcreteAttribute;
346
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
347
+ mlirFloatAttrGetTypeID;
342
348
343
349
static void bindDerived (ClassTy &c) {
344
350
c.def_static (
@@ -406,6 +412,10 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
406
412
return mlirIntegerAttrGetValueUInt (self);
407
413
},
408
414
" Returns the value of the integer attribute" );
415
+ c.def_property_readonly_static (" static_typeid" ,
416
+ [](py::object & /* class*/ ) -> MlirTypeID {
417
+ return mlirIntegerAttrGetTypeID ();
418
+ });
409
419
}
410
420
};
411
421
@@ -438,6 +448,8 @@ class PyFlatSymbolRefAttribute
438
448
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
439
449
static constexpr const char *pyClassName = " FlatSymbolRefAttr" ;
440
450
using PyConcreteAttribute::PyConcreteAttribute;
451
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
452
+ mlirFlatSymbolRefAttrGetTypeID;
441
453
442
454
static void bindDerived (ClassTy &c) {
443
455
c.def_static (
@@ -464,6 +476,8 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
464
476
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
465
477
static constexpr const char *pyClassName = " OpaqueAttr" ;
466
478
using PyConcreteAttribute::PyConcreteAttribute;
479
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
480
+ mlirOpaqueAttrGetTypeID;
467
481
468
482
static void bindDerived (ClassTy &c) {
469
483
c.def_static (
@@ -501,6 +515,8 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
501
515
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
502
516
static constexpr const char *pyClassName = " StringAttr" ;
503
517
using PyConcreteAttribute::PyConcreteAttribute;
518
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
519
+ mlirStringAttrGetTypeID;
504
520
505
521
static void bindDerived (ClassTy &c) {
506
522
c.def_static (
@@ -921,6 +937,8 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
921
937
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
922
938
static constexpr const char *pyClassName = " DictAttr" ;
923
939
using PyConcreteAttribute::PyConcreteAttribute;
940
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
941
+ mlirDictionaryAttrGetTypeID;
924
942
925
943
intptr_t dunderLen () { return mlirDictionaryAttrGetNumElements (*this ); }
926
944
@@ -1013,6 +1031,8 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1013
1031
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1014
1032
static constexpr const char *pyClassName = " TypeAttr" ;
1015
1033
using PyConcreteAttribute::PyConcreteAttribute;
1034
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1035
+ mlirTypeAttrGetTypeID;
1016
1036
1017
1037
static void bindDerived (ClassTy &c) {
1018
1038
c.def_static (
@@ -1035,6 +1055,8 @@ class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1035
1055
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1036
1056
static constexpr const char *pyClassName = " UnitAttr" ;
1037
1057
using PyConcreteAttribute::PyConcreteAttribute;
1058
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1059
+ mlirUnitAttrGetTypeID;
1038
1060
1039
1061
static void bindDerived (ClassTy &c) {
1040
1062
c.def_static (
@@ -1054,6 +1076,8 @@ class PyStridedLayoutAttribute
1054
1076
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1055
1077
static constexpr const char *pyClassName = " StridedLayoutAttr" ;
1056
1078
using PyConcreteAttribute::PyConcreteAttribute;
1079
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1080
+ mlirStridedLayoutAttrGetTypeID;
1057
1081
1058
1082
static void bindDerived (ClassTy &c) {
1059
1083
c.def_static (
@@ -1099,6 +1123,50 @@ class PyStridedLayoutAttribute
1099
1123
}
1100
1124
};
1101
1125
1126
+ py::object denseArrayAttributeCaster (PyAttribute &pyAttribute) {
1127
+ if (PyDenseBoolArrayAttribute::isaFunction (pyAttribute))
1128
+ return py::cast (PyDenseBoolArrayAttribute (pyAttribute));
1129
+ if (PyDenseI8ArrayAttribute::isaFunction (pyAttribute))
1130
+ return py::cast (PyDenseI8ArrayAttribute (pyAttribute));
1131
+ if (PyDenseI16ArrayAttribute::isaFunction (pyAttribute))
1132
+ return py::cast (PyDenseI16ArrayAttribute (pyAttribute));
1133
+ if (PyDenseI32ArrayAttribute::isaFunction (pyAttribute))
1134
+ return py::cast (PyDenseI32ArrayAttribute (pyAttribute));
1135
+ if (PyDenseI64ArrayAttribute::isaFunction (pyAttribute))
1136
+ return py::cast (PyDenseI64ArrayAttribute (pyAttribute));
1137
+ if (PyDenseF32ArrayAttribute::isaFunction (pyAttribute))
1138
+ return py::cast (PyDenseF32ArrayAttribute (pyAttribute));
1139
+ if (PyDenseF64ArrayAttribute::isaFunction (pyAttribute))
1140
+ return py::cast (PyDenseF64ArrayAttribute (pyAttribute));
1141
+ std::string msg =
1142
+ std::string (" Can't cast unknown element type DenseArrayAttr (" ) +
1143
+ std::string (py::repr (py::cast (pyAttribute))) + " )" ;
1144
+ throw py::cast_error (msg);
1145
+ }
1146
+
1147
+ py::object denseIntOrFPElementsAttributeCaster (PyAttribute &pyAttribute) {
1148
+ if (PyDenseFPElementsAttribute::isaFunction (pyAttribute))
1149
+ return py::cast (PyDenseFPElementsAttribute (pyAttribute));
1150
+ if (PyDenseIntElementsAttribute::isaFunction (pyAttribute))
1151
+ return py::cast (PyDenseIntElementsAttribute (pyAttribute));
1152
+ std::string msg =
1153
+ std::string (
1154
+ " Can't cast unknown element type DenseIntOrFPElementsAttr (" ) +
1155
+ std::string (py::repr (py::cast (pyAttribute))) + " )" ;
1156
+ throw py::cast_error (msg);
1157
+ }
1158
+
1159
+ py::object integerOrBoolAttributeCaster (PyAttribute &pyAttribute) {
1160
+ if (PyBoolAttribute::isaFunction (pyAttribute))
1161
+ return py::cast (PyBoolAttribute (pyAttribute));
1162
+ if (PyIntegerAttribute::isaFunction (pyAttribute))
1163
+ return py::cast (PyIntegerAttribute (pyAttribute));
1164
+ std::string msg =
1165
+ std::string (" Can't cast unknown element type DenseArrayAttr (" ) +
1166
+ std::string (py::repr (py::cast (pyAttribute))) + " )" ;
1167
+ throw py::cast_error (msg);
1168
+ }
1169
+
1102
1170
} // namespace
1103
1171
1104
1172
void mlir::python::populateIRAttributes (py::module &m) {
@@ -1118,20 +1186,30 @@ void mlir::python::populateIRAttributes(py::module &m) {
1118
1186
PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind (m);
1119
1187
PyDenseF64ArrayAttribute::bind (m);
1120
1188
PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind (m);
1189
+ PyGlobals::get ().registerTypeCaster (
1190
+ mlirDenseArrayAttrGetTypeID (),
1191
+ pybind11::cpp_function (denseArrayAttributeCaster));
1121
1192
1122
1193
PyArrayAttribute::bind (m);
1123
1194
PyArrayAttribute::PyArrayAttributeIterator::bind (m);
1124
1195
PyBoolAttribute::bind (m);
1125
1196
PyDenseElementsAttribute::bind (m);
1126
1197
PyDenseFPElementsAttribute::bind (m);
1127
1198
PyDenseIntElementsAttribute::bind (m);
1199
+ PyGlobals::get ().registerTypeCaster (
1200
+ mlirDenseIntOrFPElementsAttrGetTypeID (),
1201
+ pybind11::cpp_function (denseIntOrFPElementsAttributeCaster));
1202
+
1128
1203
PyDictAttribute::bind (m);
1129
1204
PyFlatSymbolRefAttribute::bind (m);
1130
1205
PyOpaqueAttribute::bind (m);
1131
1206
PyFloatAttribute::bind (m);
1132
1207
PyIntegerAttribute::bind (m);
1133
1208
PyStringAttribute::bind (m);
1134
1209
PyTypeAttribute::bind (m);
1210
+ PyGlobals::get ().registerTypeCaster (
1211
+ mlirIntegerAttrGetTypeID (),
1212
+ pybind11::cpp_function (integerOrBoolAttributeCaster));
1135
1213
PyUnitAttribute::bind (m);
1136
1214
1137
1215
PyStridedLayoutAttribute::bind (m);
0 commit comments