4
4
5
5
try :
6
6
from ..ir import *
7
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
8
7
from ..dialects import pdl , transform
9
8
except ImportError as e :
10
9
raise RuntimeError ("Error loading imports from extension module" ) from e
@@ -101,7 +100,7 @@ def _dispatch_mixed_values(
101
100
static_values .append (size )
102
101
else :
103
102
static_values .append (ShapedType .get_dynamic_size ())
104
- dynamic_values .append (_get_op_result_or_value ( size ) )
103
+ dynamic_values .append (size )
105
104
static_values = DenseI64ArrayAttr .get (static_values )
106
105
107
106
return (dynamic_values , packed_values , static_values )
@@ -204,9 +203,7 @@ class DecomposeOp:
204
203
"""Specialization for DecomposeOp class."""
205
204
206
205
def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
207
- super ().__init__ (
208
- pdl .OperationType .get (), _get_op_result_or_value (target ), loc = loc , ip = ip
209
- )
206
+ super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
210
207
211
208
212
209
class FuseIntoContainingOp :
@@ -277,9 +274,7 @@ class GeneralizeOp:
277
274
"""Specialization for GeneralizeOp class."""
278
275
279
276
def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
280
- super ().__init__ (
281
- pdl .OperationType .get (), _get_op_result_or_value (target ), loc = loc , ip = ip
282
- )
277
+ super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
283
278
284
279
285
280
class InterchangeOp :
@@ -296,7 +291,7 @@ def __init__(
296
291
pdl_operation_type = pdl .OperationType .get ()
297
292
super ().__init__ (
298
293
pdl_operation_type ,
299
- _get_op_result_or_value ( target ) ,
294
+ target ,
300
295
iterator_interchange = iterator_interchange ,
301
296
loc = loc ,
302
297
ip = ip ,
@@ -415,7 +410,7 @@ def match_op_names(
415
410
loc = None ,
416
411
ip = None ,
417
412
):
418
- ...
413
+ ...
419
414
420
415
@overload
421
416
@classmethod
@@ -428,7 +423,7 @@ def match_op_names(
428
423
loc = None ,
429
424
ip = None ,
430
425
):
431
- ...
426
+ ...
432
427
433
428
@classmethod
434
429
def match_op_names (
@@ -441,20 +436,20 @@ def match_op_names(
441
436
ip = None ,
442
437
):
443
438
if isinstance (result_type_or_target , Type ):
444
- result_type = result_type_or_target
445
- target = target_or_names
446
- names = names_or_none
439
+ result_type = result_type_or_target
440
+ target = target_or_names
441
+ names = names_or_none
447
442
else :
448
- result_type = transform .AnyOpType .get ()
449
- target = result_type_or_target
450
- names = target_or_names
443
+ result_type = transform .AnyOpType .get ()
444
+ target = result_type_or_target
445
+ names = target_or_names
451
446
452
447
if isinstance (names , str ):
453
- names = [names ]
448
+ names = [names ]
454
449
455
450
return cls (
456
451
result_type ,
457
- _get_op_result_or_value ( target ) ,
452
+ target ,
458
453
ops = ArrayAttr .get (list (map (lambda s : StringAttr .get (s ), names ))),
459
454
loc = loc ,
460
455
ip = ip ,
@@ -479,7 +474,7 @@ def __init__(
479
474
result_type ,
480
475
result_type ,
481
476
result_type ,
482
- _get_op_result_or_value ( target ) ,
477
+ target ,
483
478
dimension = dimension ,
484
479
target_size = target_size ,
485
480
divisor = divisor ,
@@ -530,9 +525,7 @@ class ScalarizeOp:
530
525
531
526
def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
532
527
pdl_operation_type = pdl .OperationType .get ()
533
- super ().__init__ (
534
- pdl_operation_type , _get_op_result_or_value (target ), loc = loc , ip = ip
535
- )
528
+ super ().__init__ (pdl_operation_type , target , loc = loc , ip = ip )
536
529
537
530
538
531
class SplitOp :
@@ -552,9 +545,7 @@ def __init__(
552
545
dynamic_split_point = None
553
546
else :
554
547
static_split_point = ShapedType .get_dynamic_size ()
555
- dynamic_split_point = _get_op_result_or_value (split_point )
556
-
557
- target = _get_op_result_or_value (target )
548
+ dynamic_split_point = split_point
558
549
559
550
super ().__init__ (
560
551
target .type ,
@@ -626,8 +617,6 @@ def __init__(
626
617
)
627
618
target = target_or_none
628
619
629
- target = _get_op_result_or_value (target )
630
-
631
620
super ().__init__ (
632
621
target .type ,
633
622
loop_types ,
@@ -750,7 +739,7 @@ def __init__(
750
739
pdl_operation_type = pdl .OperationType .get ()
751
740
super ().__init__ (
752
741
pdl_operation_type ,
753
- _get_op_result_or_value ( target ) ,
742
+ target ,
754
743
disable_multi_reduction_to_contract_patterns = disable_multi_reduction_to_contract_patterns ,
755
744
disable_transfer_permutation_map_lowering_patterns = disable_transfer_permutation_map_lowering_patterns ,
756
745
vectorize_nd_extract = vectorize_nd_extract ,
0 commit comments