Skip to content

Commit 18cf1cd

Browse files
authored
[mlir] Add PDL C & Python usage (#94714)
Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yes plumb through adding support for custom matchers through this interface, so constrained to basics initially. This also exposes greedy rewrite driver. Only way currently to define patterns is via PDL (just to keep small). The creation of the PDL pattern module could be improved to avoid folks potentially accessing the module used to construct it post construction. No ergonomic work done yet. --------- Signed-off-by: Jacques Pienaar <[email protected]>
1 parent 38ccee0 commit 18cf1cd

File tree

15 files changed

+424
-2
lines changed

15 files changed

+424
-2
lines changed

mlir/include/mlir-c/Bindings/Python/Interop.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "mlir-c/IR.h"
4040
#include "mlir-c/IntegerSet.h"
4141
#include "mlir-c/Pass.h"
42+
#include "mlir-c/Rewrite.h"
4243

4344
// The 'mlir' Python package is relocatable and supports co-existing in multiple
4445
// projects. Each project must define its outer package prefix with this define
@@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
284285
return module;
285286
}
286287

288+
/** Creates a capsule object encapsulating the raw C-API
289+
* MlirFrozenRewritePatternSet.
290+
* The returned capsule does not extend or affect ownership of any Python
291+
* objects that reference the module in any way. */
292+
static inline PyObject *
293+
mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) {
294+
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
295+
MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
296+
}
297+
298+
/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from
299+
* mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the
300+
* right type, then a null module is returned. */
301+
static inline MlirFrozenRewritePatternSet
302+
mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) {
303+
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
304+
MlirFrozenRewritePatternSet pm = {ptr};
305+
return pm;
306+
}
307+
287308
/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
288309
* The returned capsule does not extend or affect ownership of any Python
289310
* objects that reference the module in any way. */

mlir/include/mlir-c/Rewrite.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This header declares the registration and creation method for
11+
// rewrite patterns.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_C_REWRITE_H
16+
#define MLIR_C_REWRITE_H
17+
18+
#include "mlir-c/IR.h"
19+
#include "mlir-c/Support.h"
20+
#include "mlir/Config/mlir-config.h"
21+
22+
//===----------------------------------------------------------------------===//
23+
/// Opaque type declarations (see mlir-c/IR.h for more details).
24+
//===----------------------------------------------------------------------===//
25+
26+
#define DEFINE_C_API_STRUCT(name, storage) \
27+
struct name { \
28+
storage *ptr; \
29+
}; \
30+
typedef struct name name
31+
32+
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
33+
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
34+
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
35+
36+
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
37+
mlirFreezeRewritePattern(MlirRewritePatternSet op);
38+
39+
MLIR_CAPI_EXPORTED void
40+
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
41+
42+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
43+
MlirModule op, MlirFrozenRewritePatternSet patterns,
44+
MlirGreedyRewriteDriverConfig);
45+
46+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
47+
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
48+
49+
MLIR_CAPI_EXPORTED MlirPDLPatternModule
50+
mlirPDLPatternModuleFromModule(MlirModule op);
51+
52+
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
53+
54+
MLIR_CAPI_EXPORTED MlirRewritePatternSet
55+
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
56+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
57+
58+
#undef DEFINE_C_API_STRUCT
59+
60+
#endif // MLIR_C_REWRITE_H

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,27 @@ struct type_caster<MlirModule> {
198198
};
199199
};
200200

201+
/// Casts object <-> MlirFrozenRewritePatternSet.
202+
template <>
203+
struct type_caster<MlirFrozenRewritePatternSet> {
204+
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
205+
_("MlirFrozenRewritePatternSet"));
206+
bool load(handle src, bool) {
207+
py::object capsule = mlirApiObjectToCapsule(src);
208+
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
209+
return value.ptr != nullptr;
210+
}
211+
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
212+
handle) {
213+
py::object capsule = py::reinterpret_steal<py::object>(
214+
mlirPythonFrozenRewritePatternSetToCapsule(v));
215+
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
216+
.attr("FrozenRewritePatternSet")
217+
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
218+
.release();
219+
};
220+
};
221+
201222
/// Casts object <-> MlirOperation.
202223
template <>
203224
struct type_caster<MlirOperation> {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir-c/Diagnostics.h"
2323
#include "mlir-c/IR.h"
2424
#include "mlir-c/IntegerSet.h"
25+
#include "mlir-c/Transforms.h"
2526
#include "mlir/Bindings/Python/PybindAdaptors.h"
2627
#include "llvm/ADT/DenseMap.h"
2728

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "Globals.h"
1212
#include "IRModule.h"
1313
#include "Pass.h"
14+
#include "Rewrite.h"
1415

1516
namespace py = pybind11;
1617
using namespace mlir;
@@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
116117
populateIRInterfaces(irModule);
117118
populateIRTypes(irModule);
118119

120+
auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
121+
populateRewriteSubmodule(rewriteModule);
122+
119123
// Define and populate PassManager submodule.
120124
auto passModule =
121125
m.def_submodule("passmanager", "MLIR Pass Management Bindings");

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "Rewrite.h"
10+
11+
#include "IRModule.h"
12+
#include "mlir-c/Bindings/Python/Interop.h"
13+
#include "mlir-c/Rewrite.h"
14+
#include "mlir/Config/mlir-config.h"
15+
16+
namespace py = pybind11;
17+
using namespace mlir;
18+
using namespace py::literals;
19+
using namespace mlir::python;
20+
21+
namespace {
22+
23+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
24+
/// Owning Wrapper around a PDLPatternModule.
25+
class PyPDLPatternModule {
26+
public:
27+
PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
28+
PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
29+
: module(other.module) {
30+
other.module.ptr = nullptr;
31+
}
32+
~PyPDLPatternModule() {
33+
if (module.ptr != nullptr)
34+
mlirPDLPatternModuleDestroy(module);
35+
}
36+
MlirPDLPatternModule get() { return module; }
37+
38+
private:
39+
MlirPDLPatternModule module;
40+
};
41+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
42+
43+
/// Owning Wrapper around a FrozenRewritePatternSet.
44+
class PyFrozenRewritePatternSet {
45+
public:
46+
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
47+
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
48+
: set(other.set) {
49+
other.set.ptr = nullptr;
50+
}
51+
~PyFrozenRewritePatternSet() {
52+
if (set.ptr != nullptr)
53+
mlirFrozenRewritePatternSetDestroy(set);
54+
}
55+
MlirFrozenRewritePatternSet get() { return set; }
56+
57+
pybind11::object getCapsule() {
58+
return py::reinterpret_steal<py::object>(
59+
mlirPythonFrozenRewritePatternSetToCapsule(get()));
60+
}
61+
62+
static pybind11::object createFromCapsule(pybind11::object capsule) {
63+
MlirFrozenRewritePatternSet rawPm =
64+
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
65+
if (rawPm.ptr == nullptr)
66+
throw py::error_already_set();
67+
return py::cast(PyFrozenRewritePatternSet(rawPm),
68+
py::return_value_policy::move);
69+
}
70+
71+
private:
72+
MlirFrozenRewritePatternSet set;
73+
};
74+
75+
} // namespace
76+
77+
/// Create the `mlir.rewrite` here.
78+
void mlir::python::populateRewriteSubmodule(py::module &m) {
79+
//----------------------------------------------------------------------------
80+
// Mapping of the top-level PassManager
81+
//----------------------------------------------------------------------------
82+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83+
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
84+
.def(py::init<>([](MlirModule module) {
85+
return mlirPDLPatternModuleFromModule(module);
86+
}),
87+
"module"_a, "Create a PDL module from the given module.")
88+
.def("freeze", [](PyPDLPatternModule &self) {
89+
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
90+
mlirRewritePatternSetFromPDLPatternModule(self.get())));
91+
});
92+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
93+
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
94+
py::module_local())
95+
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
96+
&PyFrozenRewritePatternSet::getCapsule)
97+
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
98+
&PyFrozenRewritePatternSet::createFromCapsule);
99+
m.def(
100+
"apply_patterns_and_fold_greedily",
101+
[](MlirModule module, MlirFrozenRewritePatternSet set) {
102+
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
103+
if (mlirLogicalResultIsFailure(status))
104+
// FIXME: Not sure this is the right error to throw here.
105+
throw py::value_error("pattern application failed to converge");
106+
},
107+
"module"_a, "set"_a,
108+
"Applys the given patterns to the given module greedily while folding "
109+
"results.");
110+
}

mlir/lib/Bindings/Python/Rewrite.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
10+
#define MLIR_BINDINGS_PYTHON_REWRITE_H
11+
12+
#include "PybindUtils.h"
13+
14+
namespace mlir {
15+
namespace python {
16+
17+
void populateRewriteSubmodule(pybind11::module &m);
18+
19+
} // namespace python
20+
} // namespace mlir
21+
22+
#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
add_mlir_upstream_c_api_library(MLIRCAPITransforms
22
Passes.cpp
3+
Rewrite.cpp
34

45
LINK_LIBS PUBLIC
6+
MLIRIR
57
MLIRTransforms
8+
MLIRTransformUtils
69
)

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Rewrite.h"
10+
#include "mlir-c/Transforms.h"
11+
#include "mlir/CAPI/IR.h"
12+
#include "mlir/CAPI/Support.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
15+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+
using namespace mlir;
18+
19+
inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
20+
assert(module.ptr && "unexpected null module");
21+
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
22+
}
23+
24+
inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
25+
return {module};
26+
}
27+
28+
inline mlir::FrozenRewritePatternSet *
29+
unwrap(MlirFrozenRewritePatternSet module) {
30+
assert(module.ptr && "unexpected null module");
31+
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
32+
}
33+
34+
inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
35+
return {module};
36+
}
37+
38+
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
39+
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
40+
op.ptr = nullptr;
41+
return wrap(m);
42+
}
43+
44+
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
45+
delete unwrap(op);
46+
op.ptr = nullptr;
47+
}
48+
49+
MlirLogicalResult
50+
mlirApplyPatternsAndFoldGreedily(MlirModule op,
51+
MlirFrozenRewritePatternSet patterns,
52+
MlirGreedyRewriteDriverConfig) {
53+
return wrap(
54+
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
55+
}
56+
57+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
58+
inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
59+
assert(module.ptr && "unexpected null module");
60+
return static_cast<mlir::PDLPatternModule *>(module.ptr);
61+
}
62+
63+
inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
64+
return {module};
65+
}
66+
67+
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
68+
return wrap(new mlir::PDLPatternModule(
69+
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
70+
}
71+
72+
void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
73+
delete unwrap(op);
74+
op.ptr = nullptr;
75+
}
76+
77+
MlirRewritePatternSet
78+
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
79+
auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
80+
op.ptr = nullptr;
81+
return wrap(m);
82+
}
83+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

mlir/python/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
2121
_mlir_libs/__init__.py
2222
ir.py
2323
passmanager.py
24+
rewrite.py
2425
dialects/_ods_common.py
2526

2627
# The main _mlir module has submodules: include stubs from each.
@@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
448449
IRModule.cpp
449450
IRTypes.cpp
450451
Pass.cpp
452+
Rewrite.cpp
451453

452454
# Headers must be included explicitly so they are installed.
453455
Globals.h

0 commit comments

Comments
 (0)