Skip to content

[mlir] Add PDL C & Python usage #94714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions mlir/include/mlir-c/Bindings/Python/Interop.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Rewrite.h"

// The 'mlir' Python package is relocatable and supports co-existing in multiple
// projects. Each project must define its outer package prefix with this define
Expand Down Expand Up @@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
return module;
}

/** Creates a capsule object encapsulating the raw C-API
* MlirFrozenRewritePatternSet.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */
static inline PyObject *
mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) {
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
}

/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from
* mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the
* right type, then a null module is returned. */
static inline MlirFrozenRewritePatternSet
mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
MlirFrozenRewritePatternSet pm = {ptr};
return pm;
}

/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */
Expand Down
60 changes: 60 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header declares the registration and creation method for
// rewrite patterns.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_C_REWRITE_H
#define MLIR_C_REWRITE_H

#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Config/mlir-config.h"

//===----------------------------------------------------------------------===//
/// Opaque type declarations (see mlir-c/IR.h for more details).
//===----------------------------------------------------------------------===//

#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
}; \
typedef struct name name

DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);

MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet op);

MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);

MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);

MLIR_CAPI_EXPORTED MlirPDLPatternModule
mlirPDLPatternModuleFromModule(MlirModule op);

MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);

MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

#undef DEFINE_C_API_STRUCT

#endif // MLIR_C_REWRITE_H
21 changes: 21 additions & 0 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,27 @@ struct type_caster<MlirModule> {
};
};

/// Casts object <-> MlirFrozenRewritePatternSet.
template <>
struct type_caster<MlirFrozenRewritePatternSet> {
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
_("MlirFrozenRewritePatternSet"));
bool load(handle src, bool) {
py::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
return value.ptr != nullptr;
}
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
handle) {
py::object capsule = py::reinterpret_steal<py::object>(
mlirPythonFrozenRewritePatternSetToCapsule(v));
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
.attr("FrozenRewritePatternSet")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
};
};

/// Casts object <-> MlirOperation.
template <>
struct type_caster<MlirOperation> {
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/DenseMap.h"

Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "Globals.h"
#include "IRModule.h"
#include "Pass.h"
#include "Rewrite.h"

namespace py = pybind11;
using namespace mlir;
Expand Down Expand Up @@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
populateIRInterfaces(irModule);
populateIRTypes(irModule);

auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
populateRewriteSubmodule(rewriteModule);

// Define and populate PassManager submodule.
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
Expand Down
110 changes: 110 additions & 0 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Rewrite.h"

#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Rewrite.h"
#include "mlir/Config/mlir-config.h"

namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
using namespace mlir::python;

namespace {

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Owning Wrapper around a PDLPatternModule.
class PyPDLPatternModule {
public:
PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
: module(other.module) {
other.module.ptr = nullptr;
}
~PyPDLPatternModule() {
if (module.ptr != nullptr)
mlirPDLPatternModuleDestroy(module);
}
MlirPDLPatternModule get() { return module; }

private:
MlirPDLPatternModule module;
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

/// Owning Wrapper around a FrozenRewritePatternSet.
class PyFrozenRewritePatternSet {
public:
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
: set(other.set) {
other.set.ptr = nullptr;
}
~PyFrozenRewritePatternSet() {
if (set.ptr != nullptr)
mlirFrozenRewritePatternSetDestroy(set);
}
MlirFrozenRewritePatternSet get() { return set; }

pybind11::object getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonFrozenRewritePatternSetToCapsule(get()));
}

static pybind11::object createFromCapsule(pybind11::object capsule) {
MlirFrozenRewritePatternSet rawPm =
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
if (rawPm.ptr == nullptr)
throw py::error_already_set();
return py::cast(PyFrozenRewritePatternSet(rawPm),
py::return_value_policy::move);
}

private:
MlirFrozenRewritePatternSet set;
};

} // namespace

/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
.def(py::init<>([](MlirModule module) {
return mlirPDLPatternModuleFromModule(module);
}),
"module"_a, "Create a PDL module from the given module.")
.def("freeze", [](PyPDLPatternModule &self) {
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
});
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
py::module_local())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyFrozenRewritePatternSet::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
[](MlirModule module, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
if (mlirLogicalResultIsFailure(status))
// FIXME: Not sure this is the right error to throw here.
throw py::value_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"Applys the given patterns to the given module greedily while folding "
"results.");
}
22 changes: 22 additions & 0 deletions mlir/lib/Bindings/Python/Rewrite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H

#include "PybindUtils.h"

namespace mlir {
namespace python {

void populateRewriteSubmodule(pybind11::module &m);

} // namespace python
} // namespace mlir

#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
3 changes: 3 additions & 0 deletions mlir/lib/CAPI/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
add_mlir_upstream_c_api_library(MLIRCAPITransforms
Passes.cpp
Rewrite.cpp

LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
MLIRTransformUtils
)
83 changes: 83 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Rewrite.h"
#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
}

inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
return {module};
}

inline mlir::FrozenRewritePatternSet *
unwrap(MlirFrozenRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
}

inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
return {module};
}

MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
op.ptr = nullptr;
return wrap(m);
}

void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
delete unwrap(op);
op.ptr = nullptr;
}

MlirLogicalResult
mlirApplyPatternsAndFoldGreedily(MlirModule op,
MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig) {
return wrap(
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
}

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::PDLPatternModule *>(module.ptr);
}

inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
return {module};
}

MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
return wrap(new mlir::PDLPatternModule(
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
}

void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
delete unwrap(op);
op.ptr = nullptr;
}

MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
op.ptr = nullptr;
return wrap(m);
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
2 changes: 2 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
rewrite.py
dialects/_ods_common.py

# The main _mlir module has submodules: include stubs from each.
Expand Down Expand Up @@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
IRModule.cpp
IRTypes.cpp
Pass.cpp
Rewrite.cpp

# Headers must be included explicitly so they are installed.
Globals.h
Expand Down
Loading
Loading