Skip to content

Commit 360c629

Browse files
[mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. (#65726)
`_get_op_result_or_value` was used in mix-ins to unify the handling of op results and values. However, that function is now called in the generated constructors, such that doing so in the mix-ins is not necessary anymore.
1 parent ddc3346 commit 360c629

File tree

1 file changed

+18
-29
lines changed

1 file changed

+18
-29
lines changed

mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
try:
66
from ..ir import *
7-
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
87
from ..dialects import pdl, transform
98
except ImportError as e:
109
raise RuntimeError("Error loading imports from extension module") from e
@@ -101,7 +100,7 @@ def _dispatch_mixed_values(
101100
static_values.append(size)
102101
else:
103102
static_values.append(ShapedType.get_dynamic_size())
104-
dynamic_values.append(_get_op_result_or_value(size))
103+
dynamic_values.append(size)
105104
static_values = DenseI64ArrayAttr.get(static_values)
106105

107106
return (dynamic_values, packed_values, static_values)
@@ -204,9 +203,7 @@ class DecomposeOp:
204203
"""Specialization for DecomposeOp class."""
205204

206205
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
207-
super().__init__(
208-
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
209-
)
206+
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
210207

211208

212209
class FuseIntoContainingOp:
@@ -277,9 +274,7 @@ class GeneralizeOp:
277274
"""Specialization for GeneralizeOp class."""
278275

279276
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
280-
super().__init__(
281-
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
282-
)
277+
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
283278

284279

285280
class InterchangeOp:
@@ -296,7 +291,7 @@ def __init__(
296291
pdl_operation_type = pdl.OperationType.get()
297292
super().__init__(
298293
pdl_operation_type,
299-
_get_op_result_or_value(target),
294+
target,
300295
iterator_interchange=iterator_interchange,
301296
loc=loc,
302297
ip=ip,
@@ -415,7 +410,7 @@ def match_op_names(
415410
loc=None,
416411
ip=None,
417412
):
418-
...
413+
...
419414

420415
@overload
421416
@classmethod
@@ -428,7 +423,7 @@ def match_op_names(
428423
loc=None,
429424
ip=None,
430425
):
431-
...
426+
...
432427

433428
@classmethod
434429
def match_op_names(
@@ -441,20 +436,20 @@ def match_op_names(
441436
ip=None,
442437
):
443438
if isinstance(result_type_or_target, Type):
444-
result_type = result_type_or_target
445-
target = target_or_names
446-
names = names_or_none
439+
result_type = result_type_or_target
440+
target = target_or_names
441+
names = names_or_none
447442
else:
448-
result_type = transform.AnyOpType.get()
449-
target = result_type_or_target
450-
names = target_or_names
443+
result_type = transform.AnyOpType.get()
444+
target = result_type_or_target
445+
names = target_or_names
451446

452447
if isinstance(names, str):
453-
names = [names]
448+
names = [names]
454449

455450
return cls(
456451
result_type,
457-
_get_op_result_or_value(target),
452+
target,
458453
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
459454
loc=loc,
460455
ip=ip,
@@ -479,7 +474,7 @@ def __init__(
479474
result_type,
480475
result_type,
481476
result_type,
482-
_get_op_result_or_value(target),
477+
target,
483478
dimension=dimension,
484479
target_size=target_size,
485480
divisor=divisor,
@@ -530,9 +525,7 @@ class ScalarizeOp:
530525

531526
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
532527
pdl_operation_type = pdl.OperationType.get()
533-
super().__init__(
534-
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
535-
)
528+
super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
536529

537530

538531
class SplitOp:
@@ -552,9 +545,7 @@ def __init__(
552545
dynamic_split_point = None
553546
else:
554547
static_split_point = ShapedType.get_dynamic_size()
555-
dynamic_split_point = _get_op_result_or_value(split_point)
556-
557-
target = _get_op_result_or_value(target)
548+
dynamic_split_point = split_point
558549

559550
super().__init__(
560551
target.type,
@@ -626,8 +617,6 @@ def __init__(
626617
)
627618
target = target_or_none
628619

629-
target = _get_op_result_or_value(target)
630-
631620
super().__init__(
632621
target.type,
633622
loop_types,
@@ -750,7 +739,7 @@ def __init__(
750739
pdl_operation_type = pdl.OperationType.get()
751740
super().__init__(
752741
pdl_operation_type,
753-
_get_op_result_or_value(target),
742+
target,
754743
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
755744
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
756745
vectorize_nd_extract=vectorize_nd_extract,

0 commit comments

Comments
 (0)