Skip to content

Commit b87219f

Browse files
[mlir][python] Add basic python support for GPU dialect and passes
Differential Revision: https://reviews.llvm.org/D101449
1 parent e7db840 commit b87219f

File tree

11 files changed

+164
-0
lines changed

11 files changed

+164
-0
lines changed

mlir/include/mlir-c/Dialect/GPU.h

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===-- mlir-c/Dialect/GPU.h - C API for GPU dialect -------------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===---------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_C_DIALECT_GPU_H
11+
#define MLIR_C_DIALECT_GPU_H
12+
13+
#include "mlir-c/Registration.h"
14+
#include "mlir-c/Support.h"
15+
16+
#ifdef __cplusplus
17+
extern "C" {
18+
#endif
19+
20+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu);
21+
22+
#ifdef __cplusplus
23+
}
24+
#endif
25+
26+
#include "mlir/Dialect/GPU/Passes.capi.h.inc"
27+
28+
#endif // MLIR_C_DIALECT_GPU_H

mlir/include/mlir/Dialect/GPU/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ add_public_tablegen_target(MLIRParallelLoopMapperEnumsGen)
1818

1919
set(LLVM_TARGET_DEFINITIONS Passes.td)
2020
mlir_tablegen(Passes.h.inc -gen-pass-decls -name GPU)
21+
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GPU)
22+
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU)
2123
add_public_tablegen_target(MLIRGPUPassIncGen)
2224

2325
add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc)

mlir/lib/Bindings/Python/CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps
4141
DIALECT_NAME builtin)
4242
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps)
4343

44+
add_mlir_dialect_python_bindings(MLIRBindingsPythonGPUOps
45+
TD_FILE GPUOps.td
46+
DIALECT_NAME gpu)
47+
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonGPUOps)
48+
4449
add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps
4550
TD_FILE LinalgOps.td
4651
DIALECT_NAME linalg
@@ -133,6 +138,14 @@ add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasse
133138
)
134139
add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension)
135140

141+
add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses
142+
INSTALL_DIR
143+
python
144+
SOURCES
145+
GPUPasses.cpp
146+
)
147+
add_dependencies(MLIRBindingsPythonExtension MLIRGPUPassesBindingsPythonExtension)
148+
136149
add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPasses
137150
INSTALL_DIR
138151
python

mlir/lib/Bindings/Python/GPUOps.td

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===-- GPUOps.td - Entry point GPU_dialect bindings ------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_GPU_OPS
10+
#define PYTHON_BINDINGS_GPU_OPS
11+
12+
include "mlir/Bindings/Python/Attributes.td"
13+
include "mlir/Dialect/GPU/GPUOps.td"
14+
15+
#endif
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- GPUPasses.cpp - Pybind module for the GPU passes ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/GPU.h"
10+
11+
#include <pybind11/pybind11.h>
12+
13+
// -----------------------------------------------------------------------------
14+
// Module initialization.
15+
// -----------------------------------------------------------------------------
16+
17+
PYBIND11_MODULE(_mlirGPUPasses, m) {
18+
m.doc() = "MLIR GPU Dialect Passes";
19+
20+
// Register all GPU passes on load.
21+
mlirRegisterGPUPasses();
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .._gpu_ops_gen import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ...._cext_loader import _load_extension
6+
_cextGPUPasses = _load_extension("_mlirGPUPasses")

mlir/lib/CAPI/Dialect/CMakeLists.txt

+15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
set(LLVM_OPTIONAL_SOURCES
33
Async.cpp
44
AsyncPasses.cpp
5+
GPU.cpp
6+
GPUPasses.cpp
57
Linalg.cpp
68
LinalgPasses.cpp
79
SCF.cpp
@@ -24,6 +26,19 @@ add_mlir_public_c_api_library(MLIRCAPIAsync
2426
MLIRPass
2527
)
2628

29+
add_mlir_public_c_api_library(MLIRCAPIGPU
30+
GPU.cpp
31+
GPUPasses.cpp
32+
33+
DEPENDS
34+
MLIRGPUPassIncGen
35+
36+
LINK_LIBS PUBLIC
37+
MLIRCAPIIR
38+
MLIRGPU
39+
MLIRPass
40+
)
41+
2742
add_mlir_public_c_api_library(MLIRCAPILinalg
2843
Linalg.cpp
2944
LinalgPasses.cpp

mlir/lib/CAPI/Dialect/GPU.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//===- GPUc.cpp - C Interface for GPU dialect ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/GPU.h"
10+
#include "mlir/CAPI/Registration.h"
11+
#include "mlir/Dialect/GPU/GPUDialect.h"
12+
13+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, mlir::gpu::GPUDialect)

mlir/lib/CAPI/Dialect/GPUPasses.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- GPUPasses.cpp - C API for GPU Dialect Passes ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/CAPI/Pass.h"
10+
#include "mlir/Dialect/GPU/Passes.h"
11+
#include "mlir/Pass/Pass.h"
12+
13+
// Must include the declarations as they carry important visibility attributes.
14+
#include "mlir/Dialect/GPU/Passes.capi.h.inc"
15+
16+
using namespace mlir;
17+
18+
#ifdef __cplusplus
19+
extern "C" {
20+
#endif
21+
22+
#include "mlir/Dialect/GPU/Passes.capi.cpp.inc"
23+
24+
#ifdef __cplusplus
25+
}
26+
#endif
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
import mlir.dialects.gpu
5+
import mlir.dialects.gpu.passes
6+
from mlir.passmanager import *
7+
8+
def run(f):
9+
print("\nTEST:", f.__name__)
10+
f()
11+
12+
def testGPUPass():
13+
with Context() as context:
14+
PassManager.parse('gpu-kernel-outlining')
15+
print('SUCCESS')
16+
17+
# CHECK-LABEL: testGPUPass
18+
# CHECK: SUCCESS
19+
run(testGPUPass)

0 commit comments

Comments
 (0)