Skip to content

Commit 6e6da74

Browse files
authored
[mlir][python] add binding to #gpu.object (#88992)
1 parent 172f6dd commit 6e6da74

File tree

7 files changed

+212
-6
lines changed

7 files changed

+212
-6
lines changed

mlir/include/mlir-c/Dialect/GPU.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,31 @@ extern "C" {
1919

2020
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu);
2121

22+
//===---------------------------------------------------------------------===//
23+
// ObjectAttr
24+
//===---------------------------------------------------------------------===//
25+
26+
MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr);
27+
28+
MLIR_CAPI_EXPORTED MlirAttribute
29+
mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format,
30+
MlirStringRef objectStrRef, MlirAttribute mlirObjectProps);
31+
32+
MLIR_CAPI_EXPORTED MlirAttribute
33+
mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr);
34+
35+
MLIR_CAPI_EXPORTED uint32_t
36+
mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr);
37+
38+
MLIR_CAPI_EXPORTED MlirStringRef
39+
mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr);
40+
41+
MLIR_CAPI_EXPORTED bool
42+
mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr);
43+
44+
MLIR_CAPI_EXPORTED MlirAttribute
45+
mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr);
46+
2247
#ifdef __cplusplus
2348
}
2449
#endif
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/GPU.h"
10+
#include "mlir-c/IR.h"
11+
#include "mlir-c/Support.h"
12+
#include "mlir/Bindings/Python/PybindAdaptors.h"
13+
14+
#include <pybind11/detail/common.h>
15+
#include <pybind11/pybind11.h>
16+
17+
namespace py = pybind11;
18+
using namespace mlir;
19+
using namespace mlir::python;
20+
using namespace mlir::python::adaptors;
21+
22+
// -----------------------------------------------------------------------------
23+
// Module initialization.
24+
// -----------------------------------------------------------------------------
25+
26+
PYBIND11_MODULE(_mlirDialectsGPU, m) {
27+
m.doc() = "MLIR GPU Dialect";
28+
29+
//===-------------------------------------------------------------------===//
30+
// ObjectAttr
31+
//===-------------------------------------------------------------------===//
32+
33+
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
34+
.def_classmethod(
35+
"get",
36+
[](py::object cls, MlirAttribute target, uint32_t format,
37+
py::bytes object, std::optional<MlirAttribute> mlirObjectProps) {
38+
py::buffer_info info(py::buffer(object).request());
39+
MlirStringRef objectStrRef =
40+
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
41+
return cls(mlirGPUObjectAttrGet(
42+
mlirAttributeGetContext(target), target, format, objectStrRef,
43+
mlirObjectProps.has_value() ? *mlirObjectProps
44+
: MlirAttribute{nullptr}));
45+
},
46+
"cls"_a, "target"_a, "format"_a, "object"_a,
47+
"properties"_a = py::none(), "Gets a gpu.object from parameters.")
48+
.def_property_readonly(
49+
"target",
50+
[](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
51+
.def_property_readonly(
52+
"format",
53+
[](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
54+
.def_property_readonly(
55+
"object",
56+
[](MlirAttribute self) {
57+
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
58+
return py::bytes(stringRef.data, stringRef.length);
59+
})
60+
.def_property_readonly("properties", [](MlirAttribute self) {
61+
if (mlirGPUObjectAttrHasProperties(self))
62+
return py::cast(mlirGPUObjectAttrGetProperties(self));
63+
return py::none().cast<py::object>();
64+
});
65+
}

mlir/lib/CAPI/Dialect/GPU.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- GPUc.cpp - C Interface for GPU dialect ----------------------------===//
1+
//===- GPU.cpp - C Interface for GPU dialect ------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,5 +9,60 @@
99
#include "mlir-c/Dialect/GPU.h"
1010
#include "mlir/CAPI/Registration.h"
1111
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
12+
#include "llvm/Support/Casting.h"
1213

13-
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, mlir::gpu::GPUDialect)
14+
using namespace mlir;
15+
16+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect)
17+
18+
//===---------------------------------------------------------------------===//
19+
// ObjectAttr
20+
//===---------------------------------------------------------------------===//
21+
22+
bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) {
23+
return llvm::isa<gpu::ObjectAttr>(unwrap(attr));
24+
}
25+
26+
MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
27+
uint32_t format, MlirStringRef objectStrRef,
28+
MlirAttribute mlirObjectProps) {
29+
MLIRContext *ctx = unwrap(mlirCtx);
30+
llvm::StringRef object = unwrap(objectStrRef);
31+
DictionaryAttr objectProps;
32+
if (mlirObjectProps.ptr != nullptr)
33+
objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
34+
return wrap(gpu::ObjectAttr::get(ctx, unwrap(target),
35+
static_cast<gpu::CompilationTarget>(format),
36+
StringAttr::get(ctx, object), objectProps));
37+
}
38+
39+
MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) {
40+
gpu::ObjectAttr objectAttr =
41+
llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
42+
return wrap(objectAttr.getTarget());
43+
}
44+
45+
uint32_t mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr) {
46+
gpu::ObjectAttr objectAttr =
47+
llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
48+
return static_cast<uint32_t>(objectAttr.getFormat());
49+
}
50+
51+
MlirStringRef mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr) {
52+
gpu::ObjectAttr objectAttr =
53+
llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
54+
llvm::StringRef object = objectAttr.getObject();
55+
return mlirStringRefCreate(object.data(), object.size());
56+
}
57+
58+
bool mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr) {
59+
gpu::ObjectAttr objectAttr =
60+
llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
61+
return objectAttr.getProperties() != nullptr;
62+
}
63+
64+
MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) {
65+
gpu::ObjectAttr objectAttr =
66+
llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
67+
return wrap(objectAttr.getProperties());
68+
}

mlir/python/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
498498
MLIRCAPILinalg
499499
)
500500

501+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
502+
MODULE_NAME _mlirDialectsGPU
503+
ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
504+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
505+
SOURCES
506+
DialectGPU.cpp
507+
PRIVATE_LINK_LIBS
508+
LLVMSupport
509+
EMBED_CAPI_LINK_LIBS
510+
MLIRCAPIIR
511+
MLIRCAPIGPU
512+
)
513+
501514
declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
502515
MODULE_NAME _mlirDialectsLLVM
503516
ADD_TO_PARENT MLIRPythonSources.Dialects.llvm

mlir/python/mlir/dialects/gpu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44

55
from .._gpu_ops_gen import *
66
from .._gpu_enum_gen import *
7+
from ..._mlir_libs._mlirDialectsGPU import *

mlir/test/python/dialects/gpu/dialect.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,28 @@ def testMMAElementWiseAttr():
3030
# CHECK: %block_dim_y = gpu.block_dim y
3131
print(module)
3232
pass
33+
34+
35+
# CHECK-LABEL: testObjectAttr
36+
@run
37+
def testObjectAttr():
38+
target = Attribute.parse("#nvvm.target")
39+
format = gpu.CompilationTarget.Fatbin
40+
object = b"BC\xc0\xde5\x14\x00\x00\x05\x00\x00\x00b\x0c0$MY\xbef"
41+
properties = DictAttr.get({"O": IntegerAttr.get(IntegerType.get_signless(32), 2)})
42+
o = gpu.ObjectAttr.get(target, format, object, properties)
43+
# CHECK: #gpu.object<#nvvm.target, properties = {O = 2 : i32}, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
44+
print(o)
45+
assert o.object == object
46+
47+
o = gpu.ObjectAttr.get(target, format, object)
48+
# CHECK: #gpu.object<#nvvm.target, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
49+
print(o)
50+
51+
object = (
52+
b"//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_50"
53+
)
54+
o = gpu.ObjectAttr.get(target, format, object)
55+
# CHECK: #gpu.object<#nvvm.target, "//\0A// Generated by LLVM NVPTX Back-End\0A//\0A\0A.version 6.0\0A.target sm_50">
56+
print(o)
57+
assert o.object == object

mlir/test/python/dialects/gpu/module-to-binary-nvvm.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,20 @@ def testGPUToLLVMBin():
3434
pm = PassManager("any")
3535
pm.add("gpu-module-to-binary{format=llvm}")
3636
pm.run(module.operation)
37+
# CHECK-LABEL: gpu.binary @kernel_module1
3738
print(module)
38-
# CHECK-LABEL:gpu.binary @kernel_module1
39-
# CHECK:[#gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">]
39+
40+
o = gpu.ObjectAttr(module.body.operations[0].objects[0])
41+
# CHECK: #gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">
42+
print(o)
43+
# CHECK: #nvvm.target<chip = "sm_70">
44+
print(o.target)
45+
# CHECK: offload
46+
print(gpu.CompilationTarget(o.format))
47+
# CHECK: b'{{.*}}'
48+
print(o.object)
49+
# CHECK: None
50+
print(o.properties)
4051

4152

4253
# CHECK-LABEL: testGPUToASMBin
@@ -59,6 +70,17 @@ def testGPUToASMBin():
5970
pm = PassManager("any")
6071
pm.add("gpu-module-to-binary{format=isa}")
6172
pm.run(module.operation)
62-
print(module)
6373
# CHECK-LABEL:gpu.binary @kernel_module2
64-
# CHECK:[#gpu.object<#nvvm.target<flags = {fast}>, properties = {O = 2 : i32}, assembly = "{{.*}}">, #gpu.object<#nvvm.target, properties = {O = 2 : i32}, assembly = "{{.*}}">]
74+
print(module)
75+
76+
o = gpu.ObjectAttr(module.body.operations[0].objects[0])
77+
# CHECK: #gpu.object<#nvvm.target<flags = {fast}>
78+
print(o)
79+
# CHECK: #nvvm.target<flags = {fast}>
80+
print(o.target)
81+
# CHECK: assembly
82+
print(gpu.CompilationTarget(o.format))
83+
# CHECK: b'//\n// Generated by LLVM NVPTX Back-End{{.*}}'
84+
print(o.object)
85+
# CHECK: {O = 2 : i32}
86+
print(o.properties)

0 commit comments

Comments
 (0)