Skip to content

Commit 97f9f1a

Browse files
[mlir][python] Expose transform param types (#67421)
This exposes the Transform dialect types `AnyParamType` and `ParamType` via the Python bindings.
1 parent f9149a3 commit 97f9f1a

File tree

4 files changed

+92
-0
lines changed

4 files changed

+92
-0
lines changed

mlir/include/mlir-c/Dialect/Transform.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
2727

2828
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
2929

30+
//===---------------------------------------------------------------------===//
31+
// AnyParamType
32+
//===---------------------------------------------------------------------===//
33+
34+
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
35+
36+
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
37+
3038
//===---------------------------------------------------------------------===//
3139
// AnyValueType
3240
//===---------------------------------------------------------------------===//
@@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
4957
MLIR_CAPI_EXPORTED MlirStringRef
5058
mlirTransformOperationTypeGetOperationName(MlirType type);
5159

60+
//===---------------------------------------------------------------------===//
61+
// ParamType
62+
//===---------------------------------------------------------------------===//
63+
64+
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
65+
66+
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
67+
MlirType type);
68+
69+
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
70+
5271
#ifdef __cplusplus
5372
}
5473
#endif

mlir/lib/Bindings/Python/DialectTransform.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
3131
"Get an instance of AnyOpType in the given context.", py::arg("cls"),
3232
py::arg("context") = py::none());
3333

34+
//===-------------------------------------------------------------------===//
35+
// AnyParamType
36+
//===-------------------------------------------------------------------===//
37+
38+
auto anyParamType =
39+
mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
40+
anyParamType.def_classmethod(
41+
"get",
42+
[](py::object cls, MlirContext ctx) {
43+
return cls(mlirTransformAnyParamTypeGet(ctx));
44+
},
45+
"Get an instance of AnyParamType in the given context.", py::arg("cls"),
46+
py::arg("context") = py::none());
47+
3448
//===-------------------------------------------------------------------===//
3549
// AnyValueType
3650
//===-------------------------------------------------------------------===//
@@ -71,6 +85,27 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
7185
return py::str(operationName.data, operationName.length);
7286
},
7387
"Get the name of the payload operation accepted by the handle.");
88+
89+
//===-------------------------------------------------------------------===//
90+
// ParamType
91+
//===-------------------------------------------------------------------===//
92+
93+
auto paramType =
94+
mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
95+
paramType.def_classmethod(
96+
"get",
97+
[](py::object cls, MlirType type, MlirContext ctx) {
98+
return cls(mlirTransformParamTypeGet(ctx, type));
99+
},
100+
"Get an instance of ParamType for the given type in the given context.",
101+
py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
102+
paramType.def_property_readonly(
103+
"type",
104+
[](MlirType type) {
105+
MlirType paramType = mlirTransformParamTypeGetType(type);
106+
return paramType;
107+
},
108+
"Get the type this ParamType is associated with.");
74109
}
75110

76111
PYBIND11_MODULE(_mlirDialectsTransform, m) {

mlir/lib/CAPI/Dialect/Transform.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
2929
return wrap(transform::AnyOpType::get(unwrap(ctx)));
3030
}
3131

32+
//===---------------------------------------------------------------------===//
33+
// AnyParamType
34+
//===---------------------------------------------------------------------===//
35+
36+
bool mlirTypeIsATransformAnyParamType(MlirType type) {
37+
return isa<transform::AnyParamType>(unwrap(type));
38+
}
39+
40+
MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
41+
return wrap(transform::AnyParamType::get(unwrap(ctx)));
42+
}
43+
3244
//===---------------------------------------------------------------------===//
3345
// AnyValueType
3446
//===---------------------------------------------------------------------===//
@@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
6274
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
6375
return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
6476
}
77+
78+
//===---------------------------------------------------------------------===//
79+
// AnyOpType
80+
//===---------------------------------------------------------------------===//
81+
82+
bool mlirTypeIsATransformParamType(MlirType type) {
83+
return isa<transform::ParamType>(unwrap(type));
84+
}
85+
86+
MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
87+
return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
88+
}
89+
90+
MlirType mlirTransformParamTypeGetType(MlirType type) {
91+
return wrap(cast<transform::ParamType>(unwrap(type)).getType());
92+
}

mlir/test/python/dialects/transform.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def testTypes():
2222
any_op = transform.AnyOpType.get()
2323
print(any_op)
2424

25+
# CHECK: !transform.any_param
26+
any_param = transform.AnyParamType.get()
27+
print(any_param)
28+
2529
# CHECK: !transform.any_value
2630
any_value = transform.AnyValueType.get()
2731
print(any_value)
@@ -32,6 +36,12 @@ def testTypes():
3236
print(concrete_op)
3337
print(concrete_op.operation_name)
3438

39+
# CHECK: !transform.param<i32>
40+
# CHECK: i32
41+
param = transform.ParamType.get(IntegerType.get_signless(32))
42+
print(param)
43+
print(param.type)
44+
3545

3646
@run
3747
def testSequenceOp():

0 commit comments

Comments
 (0)