Skip to content

Commit 5fb6cc8

Browse files
committed
Address comments
Signed-off-by: Jacques Pienaar <[email protected]>
1 parent d73570c commit 5fb6cc8

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ struct type_caster<MlirFrozenRewritePatternSet> {
206206
bool load(handle src, bool) {
207207
py::object capsule = mlirApiObjectToCapsule(src);
208208
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
209-
return !mlirModuleIsNull(value);
209+
return value.ptr != nullptr;
210210
}
211211
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
212212
handle) {

mlir/python/mlir/dialects/pdl.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ._pdl_ops_gen import _Dialect
77
from .._mlir_libs._mlirDialectsPDL import *
88
from .._mlir_libs._mlirDialectsPDL import OperationType
9-
9+
from ..extras.meta import region_op
1010

1111
try:
1212
from ..ir import *
@@ -127,6 +127,9 @@ def body(self):
127127
return self.regions[0].blocks[0]
128128

129129

130+
pattern = region_op(PatternOp.__base__)
131+
132+
130133
@_ods_cext.register_operation(_Dialect, replace=True)
131134
class ReplaceOp(ReplaceOp):
132135
"""Specialization for PDL replace op class."""
@@ -195,6 +198,9 @@ def body(self):
195198
return self.regions[0].blocks[0]
196199

197200

201+
rewrite = region_op(RewriteOp)
202+
203+
198204
@_ods_cext.register_operation(_Dialect, replace=True)
199205
class TypeOp(TypeOp):
200206
"""Specialization for PDL type op class."""

mlir/test/python/integration/dialects/pdl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def add_func(a, b):
3838
m = Module.create()
3939
with InsertionPoint(m.body):
4040
# Change all arith.addi with index types to arith.muli.
41-
pattern = pdl.PatternOp(1, "addi_to_mul")
42-
with InsertionPoint(pattern.body):
41+
@pdl.pattern(benefit=1, sym_name="addi_to_mul")
42+
def pat():
4343
# Match arith.addi with index types.
4444
index_type = pdl.TypeOp(IndexType.get())
4545
operand0 = pdl.OperandOp(index_type)
@@ -49,8 +49,8 @@ def add_func(a, b):
4949
)
5050

5151
# Replace the matched op with arith.muli.
52-
rewrite = pdl.RewriteOp(op0)
53-
with InsertionPoint(rewrite.add_body()):
52+
@pdl.rewrite()
53+
def rew():
5454
newOp = pdl.OperationOp(
5555
name="arith.muli", args=[operand0, operand1], types=[index_type]
5656
)

0 commit comments

Comments
 (0)