Skip to content

Commit bfb1ba7

Browse files
committed
[MLIR][python bindings] Add TypeCaster for returning refined types from python APIs
depends on D150839 This diff uses `MlirTypeID` to register `TypeCaster`s (i.e., `[](PyType pyType) -> DerivedTy { return pyType; }`) for all concrete types (i.e., `PyConcrete<...>`) that are then queried for (by `MlirTypeID`) and called in `struct type_caster<MlirType>::cast`. The result is that anywhere an `MlirType mlirType` is returned from a python binding, that `mlirType` is automatically cast to the correct concrete type. For example: ``` c0 = arith.ConstantOp(f32, 0.0) # CHECK: F32Type(f32) print(repr(c0.result.type)) unranked_tensor_type = UnrankedTensorType.get(f32) unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result # CHECK: UnrankedTensorType print(type(unranked_tensor.type).__name__) # CHECK: UnrankedTensorType(tensor<*xf32>) print(repr(unranked_tensor.type)) ``` This functionality immediately extends to typed attributes (i.e., `attr.type`). The diff also implements similar functionality for `mlir_type_subclass`es but in a slightly different way - for such types (which have no cpp corresponding `class` or `struct`) the user must provide a type caster in python (similar to how `AttrBuilder` works) or in cpp as a `py::cpp_function`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150927
1 parent 5310be5 commit bfb1ba7

26 files changed

+460
-75
lines changed

mlir/include/mlir-c/Bindings/Python/Interop.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,23 @@
107107
* delineated). */
108108
#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
109109

110+
/** Attribute on MLIR Python objects that expose a function for downcasting the
111+
* corresponding Python object to a subclass if the object is in fact a subclass
112+
* (Concrete or mlir_type_subclass) of ir.Type. The signature of the function
113+
* is: def maybe_downcast(self) -> object where the resulting object will
114+
* (possibly) be an instance of the subclass.
115+
*/
116+
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast"
117+
118+
/** Attribute on main C extension module (_mlir) that corresponds to the
119+
* type caster registration binding. The signature of the function is:
120+
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
121+
* bool replace)
122+
* where replace indicates the typeCaster should replace any existing registered
123+
* type casters (such as those for upstream ConcreteTypes).
124+
*/
125+
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
126+
110127
/// Gets a void* from a wrapped struct. Needed because const cast is different
111128
/// between C/C++.
112129
#ifdef __cplusplus

mlir/include/mlir-c/Dialect/Transform.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
3333

3434
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
3535

36+
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void);
37+
3638
MLIR_CAPI_EXPORTED MlirType
3739
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
3840

mlir/include/mlir-c/IR.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
825825
/// Gets the type ID of the type.
826826
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
827827

828+
/// Gets the dialect a type belongs to.
829+
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type);
830+
828831
/// Checks whether a type is null.
829832
static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
830833

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "llvm/ADT/Twine.h"
2929

3030
namespace py = pybind11;
31+
using namespace py::literals;
3132

3233
// Raw CAPI type casters need to be declared before use, so always include them
3334
// first.
@@ -272,6 +273,7 @@ struct type_caster<MlirType> {
272273
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
273274
.attr("Type")
274275
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
276+
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
275277
.release();
276278
}
277279
};
@@ -424,20 +426,24 @@ class mlir_attribute_subclass : public pure_subclass {
424426
class mlir_type_subclass : public pure_subclass {
425427
public:
426428
using IsAFunctionTy = bool (*)(MlirType);
429+
using GetTypeIDFunctionTy = MlirTypeID (*)();
427430

428431
/// Subclasses by looking up the super-class dynamically.
429432
mlir_type_subclass(py::handle scope, const char *typeClassName,
430-
IsAFunctionTy isaFunction)
433+
IsAFunctionTy isaFunction,
434+
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
431435
: mlir_type_subclass(
432436
scope, typeClassName, isaFunction,
433-
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {}
437+
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
438+
getTypeIDFunction) {}
434439

435440
/// Subclasses with a provided mlir.ir.Type super-class. This must
436441
/// be used if the subclass is being defined in the same extension module
437442
/// as the mlir.ir class (otherwise, it will trigger a recursive
438443
/// initialization).
439444
mlir_type_subclass(py::handle scope, const char *typeClassName,
440-
IsAFunctionTy isaFunction, const py::object &superCls)
445+
IsAFunctionTy isaFunction, const py::object &superCls,
446+
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
441447
: pure_subclass(scope, typeClassName, superCls) {
442448
// Casting constructor. Note that it hard, if not impossible, to properly
443449
// call chain to parent `__init__` in pybind11 due to its special handling
@@ -471,6 +477,19 @@ class mlir_type_subclass : public pure_subclass {
471477
"isinstance",
472478
[isaFunction](MlirType other) { return isaFunction(other); },
473479
py::arg("other_type"));
480+
def("__repr__", [superCls, captureTypeName](py::object self) {
481+
return py::repr(superCls(self))
482+
.attr("replace")(superCls.attr("__name__"), captureTypeName);
483+
});
484+
if (getTypeIDFunction) {
485+
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
486+
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
487+
getTypeIDFunction(),
488+
pybind11::cpp_function(
489+
[thisClass = thisClass](const py::object &mlirType) {
490+
return thisClass(mlirType);
491+
}));
492+
}
474493
}
475494
};
476495

mlir/include/mlir/CAPI/Support.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,25 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
4444
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
4545
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
4646

47+
namespace llvm {
48+
49+
template <>
50+
struct DenseMapInfo<MlirTypeID> {
51+
static inline MlirTypeID getEmptyKey() {
52+
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
53+
return mlirTypeIDCreate(pointer);
54+
}
55+
static inline MlirTypeID getTombstoneKey() {
56+
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
57+
return mlirTypeIDCreate(pointer);
58+
}
59+
static inline unsigned getHashValue(const MlirTypeID &val) {
60+
return mlirTypeIDHashValue(val);
61+
}
62+
static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
63+
return mlirTypeIDEqual(lhs, rhs);
64+
}
65+
};
66+
} // namespace llvm
67+
4768
#endif // MLIR_CAPI_SUPPORT_H

mlir/lib/Bindings/Python/DialectTransform.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
3636
//===-------------------------------------------------------------------===//
3737

3838
auto operationType =
39-
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
39+
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
40+
mlirTransformOperationTypeGetTypeID);
4041
operationType.def_classmethod(
4142
"get",
4243
[](py::object cls, const std::string &operationName, MlirContext ctx) {

mlir/lib/Bindings/Python/Globals.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

12+
#include <optional>
1213
#include <string>
1314
#include <vector>
14-
#include <optional>
1515

1616
#include "PybindUtils.h"
1717

18+
#include "mlir-c/IR.h"
19+
#include "mlir/CAPI/Support.h"
20+
#include "llvm/ADT/DenseMap.h"
1821
#include "llvm/ADT/StringRef.h"
1922
#include "llvm/ADT/StringSet.h"
2023

@@ -54,16 +57,18 @@ class PyGlobals {
5457
/// entities.
5558
void loadDialectModule(llvm::StringRef dialectNamespace);
5659

57-
/// Decorator for registering a custom Dialect class. The class object must
58-
/// have a DIALECT_NAMESPACE attribute.
59-
pybind11::object registerDialectDecorator(pybind11::object pyClass);
60-
6160
/// Adds a user-friendly Attribute builder.
6261
/// Raises an exception if the mapping already exists.
6362
/// This is intended to be called by implementation code.
6463
void registerAttributeBuilder(const std::string &attributeKind,
6564
pybind11::function pyFunc);
6665

66+
/// Adds a user-friendly type caster. Raises an exception if the mapping
67+
/// already exists and replace == false. This is intended to be called by
68+
/// implementation code.
69+
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
70+
bool replace = false);
71+
6772
/// Adds a concrete implementation dialect class.
6873
/// Raises an exception if the mapping already exists.
6974
/// This is intended to be called by implementation code.
@@ -80,6 +85,10 @@ class PyGlobals {
8085
std::optional<pybind11::function>
8186
lookupAttributeBuilder(const std::string &attributeKind);
8287

88+
/// Returns the custom type caster for MlirTypeID mlirTypeID.
89+
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
90+
MlirDialect dialect);
91+
8392
/// Looks up a registered dialect class by namespace. Note that this may
8493
/// trigger loading of the defining module and can arbitrarily re-enter.
8594
std::optional<pybind11::object>
@@ -101,6 +110,10 @@ class PyGlobals {
101110
llvm::StringMap<pybind11::object> operationClassMap;
102111
/// Map of attribute ODS name to custom builder.
103112
llvm::StringMap<pybind11::object> attributeBuilderMap;
113+
/// Map of MlirTypeID to custom type caster.
114+
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
115+
/// Cache for map of MlirTypeID to custom type caster.
116+
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
104117

105118
/// Set of dialect namespaces that we have attempted to import implementation
106119
/// modules for.

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir-c/BuiltinAttributes.h"
1717
#include "mlir-c/BuiltinTypes.h"
18+
#include "mlir/Bindings/Python/PybindAdaptors.h"
1819

1920
namespace py = pybind11;
2021
using namespace mlir;
@@ -1023,8 +1024,7 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
10231024
py::arg("value"), py::arg("context") = py::none(),
10241025
"Gets a uniqued Type attribute");
10251026
c.def_property_readonly("value", [](PyTypeAttribute &self) {
1026-
return PyType(self.getContext()->getRef(),
1027-
mlirTypeAttrGetValue(self.get()));
1027+
return mlirTypeAttrGetValue(self.get());
10281028
});
10291029
}
10301030
};

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <utility>
2626

2727
namespace py = pybind11;
28+
using namespace py::literals;
2829
using namespace mlir;
2930
using namespace mlir::python;
3031

@@ -2121,13 +2122,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
21212122

21222123
/// Returns the list of types of the values held by container.
21232124
template <typename Container>
2124-
static std::vector<PyType> getValueTypes(Container &container,
2125-
PyMlirContextRef &context) {
2126-
std::vector<PyType> result;
2125+
static std::vector<MlirType> getValueTypes(Container &container,
2126+
PyMlirContextRef &context) {
2127+
std::vector<MlirType> result;
21272128
result.reserve(container.size());
21282129
for (int i = 0, e = container.size(); i < e; ++i) {
2129-
result.push_back(
2130-
PyType(context, mlirValueGetType(container.getElement(i).get())));
2130+
result.push_back(mlirValueGetType(container.getElement(i).get()));
21312131
}
21322132
return result;
21332133
}
@@ -3148,11 +3148,8 @@ void mlir::python::populateIRCore(py::module &m) {
31483148
"context",
31493149
[](PyAttribute &self) { return self.getContext().getObject(); },
31503150
"Context that owns the Attribute")
3151-
.def_property_readonly("type",
3152-
[](PyAttribute &self) {
3153-
return PyType(self.getContext()->getRef(),
3154-
mlirAttributeGetType(self));
3155-
})
3151+
.def_property_readonly(
3152+
"type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
31563153
.def(
31573154
"get_named",
31583155
[](PyAttribute &self, std::string name) {
@@ -3247,7 +3244,7 @@ void mlir::python::populateIRCore(py::module &m) {
32473244
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
32483245
if (mlirTypeIsNull(type))
32493246
throw MLIRError("Unable to parse type", errors.take());
3250-
return PyType(context->getRef(), type);
3247+
return type;
32513248
},
32523249
py::arg("asm"), py::arg("context") = py::none(),
32533250
kContextParseTypeDocstring)
@@ -3284,6 +3281,18 @@ void mlir::python::populateIRCore(py::module &m) {
32843281
printAccum.parts.append(")");
32853282
return printAccum.join();
32863283
})
3284+
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3285+
[](PyType &self) {
3286+
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3287+
assert(!mlirTypeIDIsNull(mlirTypeID) &&
3288+
"mlirTypeID was expected to be non-null.");
3289+
std::optional<pybind11::function> typeCaster =
3290+
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3291+
mlirTypeGetDialect(self));
3292+
if (!typeCaster)
3293+
return py::cast(self);
3294+
return typeCaster.value()(self);
3295+
})
32873296
.def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
32883297
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
32893298
if (!mlirTypeIDIsNull(mlirTypeID))
@@ -3387,12 +3396,8 @@ void mlir::python::populateIRCore(py::module &m) {
33873396
return printAccum.join();
33883397
},
33893398
py::arg("use_local_scope") = false, kGetNameAsOperand)
3390-
.def_property_readonly("type",
3391-
[](PyValue &self) {
3392-
return PyType(
3393-
self.getParentOperation()->getContext(),
3394-
mlirValueGetType(self.get()));
3395-
})
3399+
.def_property_readonly(
3400+
"type", [](PyValue &self) { return mlirValueGetType(self.get()); })
33963401
.def(
33973402
"replace_all_uses_with",
33983403
[](PyValue &self, PyValue &with) {

mlir/lib/Bindings/Python/IRInterfaces.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,7 @@ class PyShapedTypeComponents {
321321
py::module_local())
322322
.def_property_readonly(
323323
"element_type",
324-
[](PyShapedTypeComponents &self) {
325-
return PyType(PyMlirContext::forContext(
326-
mlirTypeGetContext(self.elementType)),
327-
self.elementType);
328-
},
324+
[](PyShapedTypeComponents &self) { return self.elementType; },
329325
"Returns the element type of the shaped type components.")
330326
.def_static(
331327
"get",

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <vector>
1515

1616
#include "mlir-c/Bindings/Python/Interop.h"
17+
#include "mlir-c/Support.h"
1718

1819
namespace py = pybind11;
1920
using namespace mlir;
@@ -72,6 +73,15 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
7273
found = std::move(pyFunc);
7374
}
7475

76+
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
77+
pybind11::function typeCaster,
78+
bool replace) {
79+
pybind11::object &found = typeCasterMap[mlirTypeID];
80+
if (found && !found.is_none() && !replace)
81+
throw std::runtime_error("Type caster is already registered");
82+
found = std::move(typeCaster);
83+
}
84+
7585
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
7686
py::object pyClass) {
7787
py::object &found = dialectClassMap[dialectNamespace];
@@ -110,6 +120,39 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
110120
return std::nullopt;
111121
}
112122

123+
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
124+
MlirDialect dialect) {
125+
{
126+
// Fast match against the class map first (common case).
127+
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
128+
if (foundIt != typeCasterMapCache.end()) {
129+
if (foundIt->second.is_none())
130+
return std::nullopt;
131+
assert(foundIt->second && "py::function is defined");
132+
return foundIt->second;
133+
}
134+
}
135+
136+
// Not found. Load the dialect namespace.
137+
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
138+
139+
// Attempt to find from the canonical map and cache.
140+
{
141+
const auto foundIt = typeCasterMap.find(mlirTypeID);
142+
if (foundIt != typeCasterMap.end()) {
143+
if (foundIt->second.is_none())
144+
return std::nullopt;
145+
assert(foundIt->second && "py::object is defined");
146+
// Positive cache.
147+
typeCasterMapCache[mlirTypeID] = foundIt->second;
148+
return foundIt->second;
149+
}
150+
// Negative cache.
151+
typeCasterMap[mlirTypeID] = py::none();
152+
return std::nullopt;
153+
}
154+
}
155+
113156
std::optional<py::object>
114157
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
115158
loadDialectModule(dialectNamespace);
@@ -164,4 +207,5 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
164207
void PyGlobals::clearImportCache() {
165208
loadedDialectModulesCache.clear();
166209
operationClassMapCache.clear();
210+
typeCasterMapCache.clear();
167211
}

0 commit comments

Comments
 (0)