Skip to content

Commit 896fd00

Browse files
committed
[mlir] Add PDL C & Python usage
Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yes plumb through adding support for custom matchers through this interface, so constrained to basics initially. Signed-off-by: Jacques Pienaar <[email protected]>
1 parent 0bc33f4 commit 896fd00

File tree

14 files changed

+416
-1
lines changed

14 files changed

+416
-1
lines changed

mlir/include/mlir-c/Bindings/Python/Interop.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "mlir-c/IR.h"
4040
#include "mlir-c/IntegerSet.h"
4141
#include "mlir-c/Pass.h"
42+
#include "mlir-c/Rewrite.h"
4243

4344
// The 'mlir' Python package is relocatable and supports co-existing in multiple
4445
// projects. Each project must define its outer package prefix with this define
@@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
284285
return module;
285286
}
286287

288+
/** Creates a capsule object encapsulating the raw C-API
289+
* MlirFrozenRewritePatternSet.
290+
* The returned capsule does not extend or affect ownership of any Python
291+
* objects that reference the module in any way. */
292+
static inline PyObject *
293+
mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) {
294+
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
295+
MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
296+
}
297+
298+
/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from
299+
* mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the
300+
* right type, then a null module is returned. */
301+
static inline MlirFrozenRewritePatternSet
302+
mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) {
303+
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
304+
MlirFrozenRewritePatternSet pm = {ptr};
305+
return pm;
306+
}
307+
287308
/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
288309
* The returned capsule does not extend or affect ownership of any Python
289310
* objects that reference the module in any way. */

mlir/include/mlir-c/Rewrite.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This header declares the registration and creation method for
11+
// rewrite patterns.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_C_REWRITE_H
16+
#define MLIR_C_REWRITE_H
17+
18+
#include "mlir-c/IR.h"
19+
#include "mlir-c/Support.h"
20+
#include "mlir/Config/mlir-config.h"
21+
22+
//===----------------------------------------------------------------------===//
23+
/// Opaque type declarations (see mlir-c/IR.h for more details).
24+
//===----------------------------------------------------------------------===//
25+
26+
#define DEFINE_C_API_STRUCT(name, storage) \
27+
struct name { \
28+
storage *ptr; \
29+
}; \
30+
typedef struct name name
31+
32+
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
33+
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
34+
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
35+
36+
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
37+
mlirFreezeRewritePattern(MlirRewritePatternSet op);
38+
39+
MLIR_CAPI_EXPORTED void
40+
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
41+
42+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
43+
MlirModule op, MlirFrozenRewritePatternSet patterns,
44+
MlirGreedyRewriteDriverConfig);
45+
46+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
47+
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
48+
49+
MLIR_CAPI_EXPORTED MlirPDLPatternModule
50+
mlirPDLPatternModuleFromModule(MlirModule op);
51+
52+
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
53+
54+
MLIR_CAPI_EXPORTED MlirRewritePatternSet
55+
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
56+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
57+
58+
#undef DEFINE_C_API_STRUCT
59+
60+
#endif // MLIR_C_REWRITE_H

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,26 @@ struct type_caster<MlirModule> {
198198
};
199199
};
200200

201+
/// Casts object <-> MlirFrozenRewritePatternSet.
202+
template <> struct type_caster<MlirFrozenRewritePatternSet> {
203+
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
204+
_("MlirFrozenRewritePatternSet"));
205+
bool load(handle src, bool) {
206+
py::object capsule = mlirApiObjectToCapsule(src);
207+
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
208+
return true; // !mlirModuleIsNull(value);
209+
}
210+
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
211+
handle) {
212+
py::object capsule = py::reinterpret_steal<py::object>(
213+
mlirPythonFrozenRewritePatternSetToCapsule(v));
214+
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
215+
.attr("FrozenRewritePatternSet")
216+
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
217+
.release();
218+
};
219+
};
220+
201221
/// Casts object <-> MlirOperation.
202222
template <>
203223
struct type_caster<MlirOperation> {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir-c/Diagnostics.h"
2323
#include "mlir-c/IR.h"
2424
#include "mlir-c/IntegerSet.h"
25+
#include "mlir-c/Transforms.h"
2526
#include "mlir/Bindings/Python/PybindAdaptors.h"
2627
#include "llvm/ADT/DenseMap.h"
2728

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "Globals.h"
1212
#include "IRModule.h"
1313
#include "Pass.h"
14+
#include "Rewrite.h"
1415

1516
namespace py = pybind11;
1617
using namespace mlir;
@@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
116117
populateIRInterfaces(irModule);
117118
populateIRTypes(irModule);
118119

120+
auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
121+
populateRewriteSubmodule(rewriteModule);
122+
119123
// Define and populate PassManager submodule.
120124
auto passModule =
121125
m.def_submodule("passmanager", "MLIR Pass Management Bindings");

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "Rewrite.h"
10+
11+
#include "IRModule.h"
12+
#include "mlir-c/Bindings/Python/Interop.h"
13+
#include "mlir-c/Rewrite.h"
14+
#include "mlir/Config/mlir-config.h"
15+
16+
namespace py = pybind11;
17+
using namespace mlir;
18+
using namespace py::literals;
19+
using namespace mlir::python;
20+
21+
namespace {
22+
23+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
24+
/// Owning Wrapper around a PDLPatternModule.
25+
class PyPDLPatternModule {
26+
public:
27+
PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
28+
PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
29+
: module(other.module) {
30+
other.module.ptr = nullptr;
31+
}
32+
~PyPDLPatternModule() {
33+
if (module.ptr != nullptr)
34+
mlirPDLPatternModuleDestroy(module);
35+
}
36+
MlirPDLPatternModule get() { return module; }
37+
38+
private:
39+
MlirPDLPatternModule module;
40+
};
41+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
42+
43+
/// Owning Wrapper around a FrozenRewritePatternSet.
44+
class PyFrozenRewritePatternSet {
45+
public:
46+
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
47+
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
48+
: set(other.set) {
49+
other.set.ptr = nullptr;
50+
}
51+
~PyFrozenRewritePatternSet() {
52+
if (set.ptr != nullptr)
53+
mlirFrozenRewritePatternSetDestroy(set);
54+
}
55+
MlirFrozenRewritePatternSet get() { return set; }
56+
57+
pybind11::object getCapsule() {
58+
return py::reinterpret_steal<py::object>(
59+
mlirPythonFrozenRewritePatternSetToCapsule(get()));
60+
}
61+
62+
static pybind11::object createFromCapsule(pybind11::object capsule) {
63+
MlirFrozenRewritePatternSet rawPm =
64+
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
65+
if (rawPm.ptr == nullptr)
66+
throw py::error_already_set();
67+
return py::cast(PyFrozenRewritePatternSet(rawPm),
68+
py::return_value_policy::move);
69+
}
70+
71+
private:
72+
MlirFrozenRewritePatternSet set;
73+
};
74+
75+
} // namespace
76+
77+
/// Create the `mlir.rewrite` here.
78+
void mlir::python::populateRewriteSubmodule(py::module &m) {
79+
//----------------------------------------------------------------------------
80+
// Mapping of the top-level PassManager
81+
//----------------------------------------------------------------------------
82+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83+
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
84+
.def(py::init<>([](MlirModule module) {
85+
return mlirPDLPatternModuleFromModule(module);
86+
}),
87+
"module"_a, "Create a PDL module from the given module.")
88+
.def("freeze", [](PyPDLPatternModule &self) {
89+
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
90+
mlirRewritePatternSetFromPDLPatternModule(self.get())));
91+
});
92+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
93+
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
94+
py::module_local())
95+
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
96+
&PyFrozenRewritePatternSet::getCapsule)
97+
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
98+
&PyFrozenRewritePatternSet::createFromCapsule);
99+
m.def(
100+
"apply_patterns_and_fold_greedily",
101+
[](MlirModule module, MlirFrozenRewritePatternSet set) {
102+
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
103+
if (mlirLogicalResultIsFailure(status))
104+
// FIXME: Not sure this is the right error to throw here.
105+
throw py::value_error("pattern application failed to converge");
106+
},
107+
"module"_a, "set"_a,
108+
"Applys the given patterns to the given module greedily while folding "
109+
"results.");
110+
}

mlir/lib/Bindings/Python/Rewrite.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
10+
#define MLIR_BINDINGS_PYTHON_REWRITE_H
11+
12+
#include "PybindUtils.h"
13+
14+
namespace mlir {
15+
namespace python {
16+
17+
void populateRewriteSubmodule(pybind11::module &m);
18+
19+
} // namespace python
20+
} // namespace mlir
21+
22+
#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
add_mlir_upstream_c_api_library(MLIRCAPITransforms
22
Passes.cpp
3+
Rewrite.cpp
34

45
LINK_LIBS PUBLIC
6+
MLIRIR
57
MLIRTransforms
8+
MLIRTransformUtils
69
)

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Rewrite.h"
10+
#include "mlir-c/Transforms.h"
11+
#include "mlir/CAPI/IR.h"
12+
#include "mlir/CAPI/Support.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
15+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+
using namespace mlir;
18+
19+
inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
20+
assert(module.ptr && "unexpected null module");
21+
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
22+
}
23+
24+
inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
25+
return {module};
26+
}
27+
28+
inline mlir::FrozenRewritePatternSet *
29+
unwrap(MlirFrozenRewritePatternSet module) {
30+
assert(module.ptr && "unexpected null module");
31+
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
32+
}
33+
34+
inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
35+
return {module};
36+
}
37+
38+
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
39+
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
40+
op.ptr = nullptr;
41+
return wrap(m);
42+
}
43+
44+
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
45+
delete unwrap(op);
46+
op.ptr = nullptr;
47+
}
48+
49+
MlirLogicalResult
50+
mlirApplyPatternsAndFoldGreedily(MlirModule op,
51+
MlirFrozenRewritePatternSet patterns,
52+
MlirGreedyRewriteDriverConfig) {
53+
return wrap(
54+
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
55+
}
56+
57+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
58+
inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
59+
assert(module.ptr && "unexpected null module");
60+
return static_cast<mlir::PDLPatternModule *>(module.ptr);
61+
}
62+
63+
inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
64+
return {module};
65+
}
66+
67+
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
68+
return wrap(new mlir::PDLPatternModule(
69+
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
70+
}
71+
72+
void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
73+
delete unwrap(op);
74+
op.ptr = nullptr;
75+
}
76+
77+
MlirRewritePatternSet
78+
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
79+
auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
80+
op.ptr = nullptr;
81+
return wrap(m);
82+
}
83+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

mlir/python/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
2121
_mlir_libs/__init__.py
2222
ir.py
2323
passmanager.py
24+
rewrite.py
2425
dialects/_ods_common.py
2526

2627
# The main _mlir module has submodules: include stubs from each.
@@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
448449
IRModule.cpp
449450
IRTypes.cpp
450451
Pass.cpp
452+
Rewrite.cpp
451453

452454
# Headers must be included explicitly so they are installed.
453455
Globals.h

mlir/python/mlir/rewrite.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._mlir_libs._mlir.rewrite import *

0 commit comments

Comments
 (0)