Skip to content

Commit 07bd300

Browse files
committed
kwonly replace arg
1 parent ea64522 commit 07bd300

File tree

6 files changed

+40
-39
lines changed

6 files changed

+40
-39
lines changed

mlir/include/mlir-c/Bindings/Python/Interop.h

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,25 +118,24 @@
118118

119119
/** Attribute on main C extension module (_mlir) that corresponds to the
120120
* type caster registration binding. The signature of the function is:
121-
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
122-
* bool replace)
123-
* where replace indicates the typeCaster should replace any existing registered
124-
* type casters (such as those for upstream ConcreteTypes). The interface of the
125-
* typeCaster is:
126-
* def type_caster(ir.Type) -> SubClassTypeT
127-
* where SubClassTypeT indicates the result should be a subclass (inherit from)
128-
* ir.Type.
121+
* def register_type_caster(MlirTypeID mlirTypeID, *, bool replace)
122+
* which then takes a typeCaster (register_type_caster is meant to be used as a
123+
* decorator from python), and where replace indicates the typeCaster should
124+
* replace any existing registered type casters (such as those for upstream
125+
* ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type)
126+
* -> SubClassTypeT where SubClassTypeT indicates the result should be a
127+
* subclass (inherit from) ir.Type.
129128
*/
130129
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
131130

132131
/** Attribute on main C extension module (_mlir) that corresponds to the
133132
* value caster registration binding. The signature of the function is:
134-
* def register_value_caster(MlirTypeID mlirTypeID, bool replace,
135-
* py::function valueCaster)
136-
* where replace indicates the valueCaster should replace any existing
137-
* registered value casters. The interface of the valueCaster is:
138-
* def value_caster(ir.Value) -> SubClassValueT
139-
* where SubClassValueT indicates the result should be a subclass (inherit from)
133+
* def register_value_caster(MlirTypeID mlirTypeID, *, bool replace)
134+
* which then takes a valueCaster (register_value_caster is meant to be used as
135+
* a decorator, from python), and where replace indicates the valueCaster should
136+
* replace any existing registered value casters. The interface of the
137+
* valueCaster is: def value_caster(ir.Value) -> SubClassValueT where
138+
* SubClassValueT indicates the result should be a subclass (inherit from)
140139
* ir.Value.
141140
*/
142141
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -497,11 +497,10 @@ class mlir_type_subclass : public pure_subclass {
497497
if (getTypeIDFunction) {
498498
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
499499
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
500-
getTypeIDFunction(),
501-
pybind11::cpp_function(
502-
[thisClass = thisClass](const py::object &mlirType) {
503-
return thisClass(mlirType);
504-
}));
500+
getTypeIDFunction())(pybind11::cpp_function(
501+
[thisClass = thisClass](const py::object &mlirType) {
502+
return thisClass(mlirType);
503+
}));
505504
}
506505
}
507506
};

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) {
4444
"dialect_namespace"_a, "dialect_class"_a,
4545
"Testing hook for directly registering a dialect")
4646
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
47-
"operation_name"_a, "operation_class"_a, "replace"_a = false,
47+
"operation_name"_a, "operation_class"_a, py::kw_only(),
48+
"replace"_a = false,
4849
"Testing hook for directly registering an operation");
4950

5051
// Aside from making the globals accessible to python, having python manage
@@ -80,16 +81,19 @@ PYBIND11_MODULE(_mlir, m) {
8081
return opClass;
8182
});
8283
},
83-
"dialect_class"_a, "replace"_a = false,
84+
"dialect_class"_a, py::kw_only(), "replace"_a = false,
8485
"Produce a class decorator for registering an Operation class as part of "
8586
"a dialect");
8687
m.def(
8788
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
88-
[](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
89-
PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
90-
replace);
89+
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
90+
return py::cpp_function([mlirTypeID,
91+
replace](py::object typeCaster) -> py::object {
92+
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
93+
return typeCaster;
94+
});
9195
},
92-
"typeid"_a, "type_caster"_a, "replace"_a = false,
96+
"typeid"_a, py::kw_only(), "replace"_a = false,
9397
"Register a type caster for casting MLIR types to custom user types.");
9498
m.def(
9599
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
@@ -101,7 +105,7 @@ PYBIND11_MODULE(_mlir, m) {
101105
return valueCaster;
102106
});
103107
},
104-
"typeid"_a, "replace"_a = false,
108+
"typeid"_a, py::kw_only(), "replace"_a = false,
105109
"Register a value caster for casting MLIR values to custom user values.");
106110

107111
// Define and populate IR submodule.

mlir/test/python/dialects/arith_dialect.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ def testFastMathFlags():
4141
def testArithValue():
4242
def _binary_op(lhs, rhs, op: str) -> "ArithValue":
4343
op = op.capitalize()
44-
if arith._is_float_type(lhs.type):
44+
if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
4545
op += "F"
46-
elif arith._is_integer_like_type(lhs.type):
46+
elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
47+
lhs.type
48+
):
4749
op += "I"
4850
else:
4951
raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")

mlir/test/python/dialects/python_test.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -508,31 +508,29 @@ def testCustomTypeTypeCaster():
508508
# CHECK: Type caster is already registered
509509
try:
510510

511+
@register_type_caster(c.typeid)
511512
def type_caster(pytype):
512513
return test.TestIntegerRankedTensorType(pytype)
513514

514-
register_type_caster(c.typeid, type_caster)
515515
except RuntimeError as e:
516516
print(e)
517517

518-
def type_caster(pytype):
519-
return RankedTensorType(pytype)
520-
521518
# python_test dialect registers a caster for RankedTensorType in its extension (pybind) module.
522519
# So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
523-
register_type_caster(c.typeid, type_caster, replace=True)
520+
@register_type_caster(c.typeid, replace=True)
521+
def type_caster(pytype):
522+
return RankedTensorType(pytype)
524523

525524
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
526525
# CHECK: tensor<10x10xi5>
527526
print(d.type)
528527
# CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
529528
print("ranked tensor type", repr(d.type))
530529

530+
@register_type_caster(c.typeid, replace=True)
531531
def type_caster(pytype):
532532
return test.TestIntegerRankedTensorType(pytype)
533533

534-
register_type_caster(c.typeid, type_caster, replace=True)
535-
536534
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
537535
# CHECK: tensor<10x10xi5>
538536
print(d.type)

mlir/test/python/lib/PythonTestModule.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
7474
MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
7575

7676
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
77-
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
78-
mlirRankedTensorTypeID,
77+
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID,
78+
"replace"_a = true)(
7979
pybind11::cpp_function([typeCls](const py::object &mlirType) {
8080
return typeCls.get_class()(mlirType);
81-
}),
82-
/*replace=*/true);
81+
}));
8382

8483
auto valueCls = mlir_value_subclass(m, "TestTensorValue",
8584
mlirTypeIsAPythonTestTestTensorValue)

0 commit comments

Comments
 (0)