Skip to content

Commit 9566ee2

Browse files
committed
[MLIR][python bindings] TypeCasters for Attributes
Differential Revision: https://reviews.llvm.org/D151840
1 parent 31fbfa5 commit 9566ee2

File tree

9 files changed

+288
-36
lines changed

9 files changed

+288
-36
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map);
4545
/// Returns the affine map wrapped in the given affine map attribute.
4646
MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr);
4747

48+
/// Returns the typeID of an AffineMap attribute.
49+
MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void);
50+
4851
//===----------------------------------------------------------------------===//
4952
// Array attribute.
5053
//===----------------------------------------------------------------------===//
@@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr);
6467
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr,
6568
intptr_t pos);
6669

70+
/// Returns the typeID of an Array attribute.
71+
MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void);
72+
6773
//===----------------------------------------------------------------------===//
6874
// Dictionary attribute.
6975
//===----------------------------------------------------------------------===//
@@ -89,6 +95,9 @@ mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos);
8995
MLIR_CAPI_EXPORTED MlirAttribute
9096
mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name);
9197

98+
/// Returns the typeID of a Dictionary attribute.
99+
MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void);
100+
92101
//===----------------------------------------------------------------------===//
93102
// Floating point attribute.
94103
//===----------------------------------------------------------------------===//
@@ -115,6 +124,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc,
115124
/// the value as double.
116125
MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr);
117126

127+
/// Returns the typeID of a Float attribute.
128+
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void);
129+
118130
//===----------------------------------------------------------------------===//
119131
// Integer attribute.
120132
//===----------------------------------------------------------------------===//
@@ -142,6 +154,9 @@ MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr);
142154
/// is of unsigned type and fits into an unsigned 64-bit integer.
143155
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
144156

157+
/// Returns the typeID of an Integer attribute.
158+
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
159+
145160
//===----------------------------------------------------------------------===//
146161
// Bool attribute.
147162
//===----------------------------------------------------------------------===//
@@ -162,6 +177,9 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
162177
/// Checks whether the given attribute is an integer set attribute.
163178
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
164179

180+
/// Returns the typeID of an IntegerSet attribute.
181+
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
182+
165183
//===----------------------------------------------------------------------===//
166184
// Opaque attribute.
167185
//===----------------------------------------------------------------------===//
@@ -185,6 +203,9 @@ mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
185203
/// the context in which the attribute lives.
186204
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr);
187205

206+
/// Returns the typeID of an Opaque attribute.
207+
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void);
208+
188209
//===----------------------------------------------------------------------===//
189210
// String attribute.
190211
//===----------------------------------------------------------------------===//
@@ -206,6 +227,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type,
206227
/// long as the context in which the attribute lives.
207228
MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr);
208229

230+
/// Returns the typeID of a String attribute.
231+
MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void);
232+
209233
//===----------------------------------------------------------------------===//
210234
// SymbolRef attribute.
211235
//===----------------------------------------------------------------------===//
@@ -239,6 +263,9 @@ mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr);
239263
MLIR_CAPI_EXPORTED MlirAttribute
240264
mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos);
241265

266+
/// Returns the typeID of an SymbolRef attribute.
267+
MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void);
268+
242269
//===----------------------------------------------------------------------===//
243270
// Flat SymbolRef attribute.
244271
//===----------------------------------------------------------------------===//
@@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
256283
MLIR_CAPI_EXPORTED MlirStringRef
257284
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
258285

286+
/// Returns the typeID of an FlatSymbolRef attribute.
287+
MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
288+
259289
//===----------------------------------------------------------------------===//
260290
// Type attribute.
261291
//===----------------------------------------------------------------------===//
@@ -270,6 +300,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type);
270300
/// Returns the type stored in the given type attribute.
271301
MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr);
272302

303+
/// Returns the typeID of a Type attribute.
304+
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void);
305+
273306
//===----------------------------------------------------------------------===//
274307
// Unit attribute.
275308
//===----------------------------------------------------------------------===//
@@ -280,6 +313,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr);
280313
/// Creates a unit attribute in the given context.
281314
MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx);
282315

316+
/// Returns the typeID of a Unit attribute.
317+
MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void);
318+
283319
//===----------------------------------------------------------------------===//
284320
// Elements attributes.
285321
//===----------------------------------------------------------------------===//
@@ -306,6 +342,8 @@ MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
306342
// Dense array attribute.
307343
//===----------------------------------------------------------------------===//
308344

345+
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void);
346+
309347
/// Checks whether the given attribute is a dense array attribute.
310348
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr);
311349
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr);
@@ -370,6 +408,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr);
370408
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr);
371409
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
372410

411+
/// Returns the typeID of an DenseIntOrFPElements attribute.
412+
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void);
413+
373414
/// Creates a dense elements attribute with the given Shaped type and elements
374415
/// in the same context as the type.
375416
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(
@@ -612,6 +653,9 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr);
612653
MLIR_CAPI_EXPORTED MlirAttribute
613654
mlirSparseElementsAttrGetValues(MlirAttribute attr);
614655

656+
/// Returns the typeID of a SparseElements attribute.
657+
MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void);
658+
615659
//===----------------------------------------------------------------------===//
616660
// Strided layout attribute.
617661
//===----------------------------------------------------------------------===//
@@ -635,6 +679,9 @@ mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr);
635679
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr,
636680
intptr_t pos);
637681

682+
/// Returns the typeID of a StridedLayout attribute.
683+
MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void);
684+
638685
#ifdef __cplusplus
639686
}
640687
#endif

mlir/include/mlir-c/IR.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,9 @@ MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);
860860
/// Gets the type id of the attribute.
861861
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);
862862

863+
/// Gets the dialect of the attribute.
864+
MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute);
865+
863866
/// Checks whether an attribute is null.
864867
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
865868

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ struct type_caster<MlirAttribute> {
9797
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
9898
.attr("Attribute")
9999
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
100+
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
100101
.release();
101102
}
102103
};

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
8080
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
8181
static constexpr const char *pyClassName = "AffineMapAttr";
8282
using PyConcreteAttribute::PyConcreteAttribute;
83+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
84+
mlirAffineMapAttrGetTypeID;
8385

8486
static void bindDerived(ClassTy &c) {
8587
c.def_static(
@@ -259,6 +261,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
259261
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
260262
static constexpr const char *pyClassName = "ArrayAttr";
261263
using PyConcreteAttribute::PyConcreteAttribute;
264+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
265+
mlirArrayAttrGetTypeID;
262266

263267
class PyArrayAttributeIterator {
264268
public:
@@ -339,6 +343,8 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
339343
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
340344
static constexpr const char *pyClassName = "FloatAttr";
341345
using PyConcreteAttribute::PyConcreteAttribute;
346+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
347+
mlirFloatAttrGetTypeID;
342348

343349
static void bindDerived(ClassTy &c) {
344350
c.def_static(
@@ -406,6 +412,10 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
406412
return mlirIntegerAttrGetValueUInt(self);
407413
},
408414
"Returns the value of the integer attribute");
415+
c.def_property_readonly_static("static_typeid",
416+
[](py::object & /*class*/) -> MlirTypeID {
417+
return mlirIntegerAttrGetTypeID();
418+
});
409419
}
410420
};
411421

@@ -438,6 +448,8 @@ class PyFlatSymbolRefAttribute
438448
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
439449
static constexpr const char *pyClassName = "FlatSymbolRefAttr";
440450
using PyConcreteAttribute::PyConcreteAttribute;
451+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
452+
mlirFlatSymbolRefAttrGetTypeID;
441453

442454
static void bindDerived(ClassTy &c) {
443455
c.def_static(
@@ -464,6 +476,8 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
464476
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
465477
static constexpr const char *pyClassName = "OpaqueAttr";
466478
using PyConcreteAttribute::PyConcreteAttribute;
479+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
480+
mlirOpaqueAttrGetTypeID;
467481

468482
static void bindDerived(ClassTy &c) {
469483
c.def_static(
@@ -501,6 +515,8 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
501515
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
502516
static constexpr const char *pyClassName = "StringAttr";
503517
using PyConcreteAttribute::PyConcreteAttribute;
518+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
519+
mlirStringAttrGetTypeID;
504520

505521
static void bindDerived(ClassTy &c) {
506522
c.def_static(
@@ -921,6 +937,8 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
921937
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
922938
static constexpr const char *pyClassName = "DictAttr";
923939
using PyConcreteAttribute::PyConcreteAttribute;
940+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
941+
mlirDictionaryAttrGetTypeID;
924942

925943
intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
926944

@@ -1013,6 +1031,8 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
10131031
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
10141032
static constexpr const char *pyClassName = "TypeAttr";
10151033
using PyConcreteAttribute::PyConcreteAttribute;
1034+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1035+
mlirTypeAttrGetTypeID;
10161036

10171037
static void bindDerived(ClassTy &c) {
10181038
c.def_static(
@@ -1035,6 +1055,8 @@ class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
10351055
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
10361056
static constexpr const char *pyClassName = "UnitAttr";
10371057
using PyConcreteAttribute::PyConcreteAttribute;
1058+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1059+
mlirUnitAttrGetTypeID;
10381060

10391061
static void bindDerived(ClassTy &c) {
10401062
c.def_static(
@@ -1054,6 +1076,8 @@ class PyStridedLayoutAttribute
10541076
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
10551077
static constexpr const char *pyClassName = "StridedLayoutAttr";
10561078
using PyConcreteAttribute::PyConcreteAttribute;
1079+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1080+
mlirStridedLayoutAttrGetTypeID;
10571081

10581082
static void bindDerived(ClassTy &c) {
10591083
c.def_static(
@@ -1099,6 +1123,50 @@ class PyStridedLayoutAttribute
10991123
}
11001124
};
11011125

1126+
py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1127+
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1128+
return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
1129+
if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1130+
return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
1131+
if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1132+
return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
1133+
if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1134+
return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
1135+
if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1136+
return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
1137+
if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1138+
return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
1139+
if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1140+
return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
1141+
std::string msg =
1142+
std::string("Can't cast unknown element type DenseArrayAttr (") +
1143+
std::string(py::repr(py::cast(pyAttribute))) + ")";
1144+
throw py::cast_error(msg);
1145+
}
1146+
1147+
py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1148+
if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1149+
return py::cast(PyDenseFPElementsAttribute(pyAttribute));
1150+
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1151+
return py::cast(PyDenseIntElementsAttribute(pyAttribute));
1152+
std::string msg =
1153+
std::string(
1154+
"Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1155+
std::string(py::repr(py::cast(pyAttribute))) + ")";
1156+
throw py::cast_error(msg);
1157+
}
1158+
1159+
py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1160+
if (PyBoolAttribute::isaFunction(pyAttribute))
1161+
return py::cast(PyBoolAttribute(pyAttribute));
1162+
if (PyIntegerAttribute::isaFunction(pyAttribute))
1163+
return py::cast(PyIntegerAttribute(pyAttribute));
1164+
std::string msg =
1165+
std::string("Can't cast unknown element type DenseArrayAttr (") +
1166+
std::string(py::repr(py::cast(pyAttribute))) + ")";
1167+
throw py::cast_error(msg);
1168+
}
1169+
11021170
} // namespace
11031171

11041172
void mlir::python::populateIRAttributes(py::module &m) {
@@ -1118,20 +1186,30 @@ void mlir::python::populateIRAttributes(py::module &m) {
11181186
PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
11191187
PyDenseF64ArrayAttribute::bind(m);
11201188
PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1189+
PyGlobals::get().registerTypeCaster(
1190+
mlirDenseArrayAttrGetTypeID(),
1191+
pybind11::cpp_function(denseArrayAttributeCaster));
11211192

11221193
PyArrayAttribute::bind(m);
11231194
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
11241195
PyBoolAttribute::bind(m);
11251196
PyDenseElementsAttribute::bind(m);
11261197
PyDenseFPElementsAttribute::bind(m);
11271198
PyDenseIntElementsAttribute::bind(m);
1199+
PyGlobals::get().registerTypeCaster(
1200+
mlirDenseIntOrFPElementsAttrGetTypeID(),
1201+
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1202+
11281203
PyDictAttribute::bind(m);
11291204
PyFlatSymbolRefAttribute::bind(m);
11301205
PyOpaqueAttribute::bind(m);
11311206
PyFloatAttribute::bind(m);
11321207
PyIntegerAttribute::bind(m);
11331208
PyStringAttribute::bind(m);
11341209
PyTypeAttribute::bind(m);
1210+
PyGlobals::get().registerTypeCaster(
1211+
mlirIntegerAttrGetTypeID(),
1212+
pybind11::cpp_function(integerOrBoolAttributeCaster));
11351213
PyUnitAttribute::bind(m);
11361214

11371215
PyStridedLayoutAttribute::bind(m);

0 commit comments

Comments
 (0)