Skip to content

Commit 256c36f

Browse files
committed
[mlir][python] add binding to gpu.object
1 parent b1385db commit 256c36f

File tree

7 files changed

+211
-5
lines changed

7 files changed

+211
-5
lines changed

mlir/include/mlir-c/Dialect/GPU.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,31 @@ extern "C" {
1919

2020
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu);
2121

22+
//===---------------------------------------------------------------------===//
23+
// ObjectAttr
24+
//===---------------------------------------------------------------------===//
25+
26+
MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr);
27+
28+
MLIR_CAPI_EXPORTED MlirAttribute
29+
mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format,
30+
MlirStringRef objectStrRef, MlirAttribute mlirObjectProps);
31+
32+
MLIR_CAPI_EXPORTED MlirAttribute
33+
mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr);
34+
35+
MLIR_CAPI_EXPORTED uint32_t
36+
mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr);
37+
38+
MLIR_CAPI_EXPORTED MlirStringRef
39+
mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr);
40+
41+
MLIR_CAPI_EXPORTED bool
42+
mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr);
43+
44+
MLIR_CAPI_EXPORTED MlirAttribute
45+
mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr);
46+
2247
#ifdef __cplusplus
2348
}
2449
#endif
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/GPU.h"
10+
#include "mlir-c/IR.h"
11+
#include "mlir-c/Support.h"
12+
#include "mlir/Bindings/Python/PybindAdaptors.h"
13+
14+
#include <pybind11/detail/common.h>
15+
#include <pybind11/pybind11.h>
16+
17+
namespace py = pybind11;
18+
using namespace mlir;
19+
using namespace mlir::python;
20+
using namespace mlir::python::adaptors;
21+
22+
// -----------------------------------------------------------------------------
23+
// Module initialization.
24+
// -----------------------------------------------------------------------------
25+
26+
PYBIND11_MODULE(_mlirDialectsGPU, m) {
27+
m.doc() = "MLIR GPU Dialect";
28+
29+
//===-------------------------------------------------------------------===//
30+
// ObjectAttr
31+
//===-------------------------------------------------------------------===//
32+
33+
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
34+
.def_classmethod(
35+
"get",
36+
[](py::object cls, MlirAttribute target, uint32_t format,
37+
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
38+
std::optional<MlirContext> context) {
39+
py::buffer_info info(py::buffer(object).request());
40+
MlirStringRef objectStrRef =
41+
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
42+
return cls(mlirGPUObjectAttrGet(
43+
context.has_value() ? *context
44+
: mlirAttributeGetContext(target),
45+
target, format, objectStrRef,
46+
mlirObjectProps.has_value() ? *mlirObjectProps
47+
: MlirAttribute{nullptr}));
48+
},
49+
"cls"_a, "target"_a, "format"_a, "object"_a,
50+
"properties"_a = py::none(), "context"_a = py::none(),
51+
"Gets a gpu.object from parameters.")
52+
.def_property_readonly(
53+
"target",
54+
[](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
55+
.def_property_readonly(
56+
"format",
57+
[](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
58+
.def_property_readonly(
59+
"object",
60+
[](MlirAttribute self) {
61+
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
62+
return py::bytes(stringRef.data, stringRef.length);
63+
})
64+
.def_property_readonly("properties", [](MlirAttribute self) {
65+
if (mlirGPUObjectAttrHasProperties(self))
66+
return py::cast(mlirGPUObjectAttrGetProperties(self));
67+
return py::none().cast<py::object>();
68+
});
69+
}

mlir/lib/CAPI/Dialect/GPU.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- GPUc.cpp - C Interface for GPU dialect ----------------------------===//
1+
//===- GPU.cpp - C Interface for GPU dialect ------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,5 +9,57 @@
99
#include "mlir-c/Dialect/GPU.h"
1010
#include "mlir/CAPI/Registration.h"
1111
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
12+
#include "llvm/Support/Casting.h"
1213

1314
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, mlir::gpu::GPUDialect)
15+
16+
//===---------------------------------------------------------------------===//
17+
// ObjectAttr
18+
//===---------------------------------------------------------------------===//
19+
20+
bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) {
21+
return llvm::isa<mlir::gpu::ObjectAttr>(unwrap(attr));
22+
}
23+
24+
MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
25+
uint32_t format, MlirStringRef objectStrRef,
26+
MlirAttribute mlirObjectProps) {
27+
mlir::MLIRContext *ctx = unwrap(mlirCtx);
28+
llvm::StringRef object = unwrap(objectStrRef);
29+
mlir::DictionaryAttr objectProps =
30+
llvm::cast<mlir::DictionaryAttr>(unwrap(mlirObjectProps));
31+
return wrap(mlir::gpu::ObjectAttr::get(
32+
ctx, unwrap(target), static_cast<mlir::gpu::CompilationTarget>(format),
33+
mlir::StringAttr::get(ctx, object), objectProps));
34+
}
35+
36+
MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) {
37+
mlir::gpu::ObjectAttr objectAttr =
38+
llvm::cast<mlir::gpu::ObjectAttr>(unwrap(mlirObjectAttr));
39+
return wrap(objectAttr.getTarget());
40+
}
41+
42+
uint32_t mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr) {
43+
mlir::gpu::ObjectAttr objectAttr =
44+
llvm::cast<mlir::gpu::ObjectAttr>(unwrap(mlirObjectAttr));
45+
return static_cast<uint32_t>(objectAttr.getFormat());
46+
}
47+
48+
MlirStringRef mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr) {
49+
mlir::gpu::ObjectAttr objectAttr =
50+
llvm::cast<mlir::gpu::ObjectAttr>(unwrap(mlirObjectAttr));
51+
llvm::StringRef object = objectAttr.getObject();
52+
return mlirStringRefCreate(object.data(), object.size());
53+
}
54+
55+
bool mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr) {
56+
mlir::gpu::ObjectAttr objectAttr =
57+
llvm::cast<mlir::gpu::ObjectAttr>(unwrap(mlirObjectAttr));
58+
return objectAttr.getProperties() != nullptr;
59+
}
60+
61+
MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) {
62+
mlir::gpu::ObjectAttr objectAttr =
63+
llvm::cast<mlir::gpu::ObjectAttr>(unwrap(mlirObjectAttr));
64+
return wrap(objectAttr.getProperties());
65+
}

mlir/python/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,17 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
498498
MLIRCAPILinalg
499499
)
500500

501+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
502+
MODULE_NAME _mlirDialectsGPU
503+
ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
504+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
505+
SOURCES
506+
DialectGPU.cpp
507+
EMBED_CAPI_LINK_LIBS
508+
MLIRCAPIIR
509+
MLIRCAPIGPU
510+
)
511+
501512
declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
502513
MODULE_NAME _mlirDialectsLLVM
503514
ADD_TO_PARENT MLIRPythonSources.Dialects.llvm

mlir/python/mlir/dialects/gpu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44

55
from .._gpu_ops_gen import *
66
from .._gpu_enum_gen import *
7+
from ..._mlir_libs._mlirDialectsGPU import *

mlir/test/python/dialects/gpu/dialect.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,29 @@ def testMMAElementWiseAttr():
3030
# CHECK: %block_dim_y = gpu.block_dim y
3131
print(module)
3232
pass
33+
34+
35+
# CHECK-LABEL: testObjectAttr
36+
@run
37+
def testObjectAttr():
38+
module = Module.create()
39+
target = Attribute.parse("#nvvm.target")
40+
format = gpu.CompilationTarget.Fatbin
41+
object = b"BC\xc0\xde5\x14\x00\x00\x05\x00\x00\x00b\x0c0$MY\xbef"
42+
properties = DictAttr.get({"O": IntegerAttr.get(IntegerType.get_signless(32), 2)})
43+
o = gpu.ObjectAttr.get(target, format, object, properties)
44+
# CHECK: #gpu.object<#nvvm.target, properties = {O = 2 : i32}, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
45+
print(o)
46+
assert o.object == object
47+
48+
o = gpu.ObjectAttr.get(target, format, object)
49+
# CHECK: #gpu.object<#nvvm.target, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
50+
print(o)
51+
52+
object = (
53+
b"//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_50"
54+
)
55+
o = gpu.ObjectAttr.get(target, format, object)
56+
# CHECK: #gpu.object<#nvvm.target, "//\0A// Generated by LLVM NVPTX Back-End\0A//\0A\0A.version 6.0\0A.target sm_50">
57+
print(o)
58+
assert o.object == object

mlir/test/python/dialects/gpu/module-to-binary-nvvm.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,20 @@ def testGPUToLLVMBin():
3434
pm = PassManager("any")
3535
pm.add("gpu-module-to-binary{format=llvm}")
3636
pm.run(module.operation)
37+
# CHECK-LABEL: gpu.binary @kernel_module1
3738
print(module)
38-
# CHECK-LABEL:gpu.binary @kernel_module1
39-
# CHECK:[#gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">]
39+
40+
o = gpu.ObjectAttr(module.body.operations[0].objects[0])
41+
# CHECK: #gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">
42+
print(o)
43+
# CHECK: #nvvm.target<chip = "sm_70">
44+
print(o.target)
45+
# CHECK: offload
46+
print(gpu.CompilationTarget(o.format))
47+
# CHECK: b'{{.*}}'
48+
print(o.object)
49+
# CHECK: None
50+
print(o.properties)
4051

4152

4253
# CHECK-LABEL: testGPUToASMBin
@@ -59,6 +70,17 @@ def testGPUToASMBin():
5970
pm = PassManager("any")
6071
pm.add("gpu-module-to-binary{format=isa}")
6172
pm.run(module.operation)
62-
print(module)
6373
# CHECK-LABEL:gpu.binary @kernel_module2
64-
# CHECK:[#gpu.object<#nvvm.target<flags = {fast}>, properties = {O = 2 : i32}, assembly = "{{.*}}">, #gpu.object<#nvvm.target, properties = {O = 2 : i32}, assembly = "{{.*}}">]
74+
print(module)
75+
76+
o = gpu.ObjectAttr(module.body.operations[0].objects[0])
77+
# CHECK: #gpu.object<#nvvm.target<flags = {fast}>
78+
print(o)
79+
# CHECK: #nvvm.target<flags = {fast}>
80+
print(o.target)
81+
# CHECK: assembly
82+
print(gpu.CompilationTarget(o.format))
83+
# CHECK: b'//\n// Generated by LLVM NVPTX Back-End{{.*}}'
84+
print(o.object)
85+
# CHECK: {O = 2 : i32}
86+
print(o.properties)

0 commit comments

Comments
 (0)