Skip to content

Commit bae8e1f

Browse files
authored
[MLIR][DRR] Fix inconsistent operand and arg index usage (#139816)
Background issue: #139813 In [emitEitherOperandMatch()](https://github.com/llvm/llvm-project/blob/e62fc14a5d214f801758b35bdcad0c8efc65e8b8/mlir/tools/mlir-tblgen/RewriterGen.cpp#L774) we check if `op.getArg(argIndex)` is a `NamedTypeConstraint`: ```cpp } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) { emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), operandIndex, /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), /*argName=*/eitherArgTree.getArgName(i), argIndex, /*variadicSubIndex=*/std::nullopt); ++operandIndex; } ``` but in `emitOperandMatch()` we cast on `op.getArg(operandIndex)`, which is incorrect if the operation has attributes or other non-operand arguments before its operands.
1 parent 6e574a4 commit bae8e1f

File tree

3 files changed

+40
-13
lines changed

3 files changed

+40
-13
lines changed

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,6 +1873,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
18731873
let results = (outs I32:$output);
18741874
}
18751875

1876+
def TestEitherOpC : TEST_Op<"either_op_c"> {
1877+
let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
1878+
let results = (outs I32:$output);
1879+
}
1880+
18761881
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
18771882
(TestEitherOpB $arg2, $x)>;
18781883

@@ -1884,6 +1889,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
18841889
$x),
18851890
(TestEitherOpB $arg2, $x)>;
18861891

1892+
def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
1893+
(TestEitherOpB $arg1, $arg2)>;
1894+
18871895
def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
18881896
let arguments = (ins I32:$arg0);
18891897
let results = (outs I32:$output);

mlir/test/mlir-tblgen/pattern.mlir

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -609,17 +609,17 @@ func.func @redundantTest(%arg0: i32) -> i32 {
609609
// Test either directive
610610
//===----------------------------------------------------------------------===//
611611

612-
// CHECK: @either_dag_leaf_only
613-
func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
612+
// CHECK-LABEL: @eitherDagLeafOnly
613+
func.func @eitherDagLeafOnly(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
614614
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
615615
%0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
616616
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
617617
%1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
618618
return
619619
}
620620

621-
// CHECK: @either_dag_leaf_dag_node
622-
func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
621+
// CHECK-LABEL: @eitherDagLeafDagNode
622+
func.func @eitherDagLeafDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
623623
%0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
624624
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
625625
%1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
@@ -628,8 +628,8 @@ func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
628628
return
629629
}
630630

631-
// CHECK: @either_dag_node_dag_node
632-
func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
631+
// CHECK-LABEL: @eitherDagNodeDagNode
632+
func.func @eitherDagNodeDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
633633
%0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
634634
%1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32
635635
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
@@ -639,24 +639,38 @@ func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
639639
return
640640
}
641641

642+
// CHECK-LABEL: @testEitherOpWithAttr
643+
func.func @testEitherOpWithAttr(%arg0 : i32, %arg1 : i16) -> () {
644+
// CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
645+
%0 = "test.either_op_c"(%arg0, %arg1) {attr = 0 : i32} : (i32, i16) -> i32
646+
// CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
647+
%1 = "test.either_op_c"(%arg1, %arg0) {attr = 0 : i32} : (i16, i32) -> i32
648+
// CHECK: "test.either_op_c"(%arg0, %arg1) <{attr = 1 : i32}> : (i32, i16) -> i32
649+
%2 = "test.either_op_c"(%arg0, %arg1) {attr = 1 : i32} : (i32, i16) -> i32
650+
return
651+
}
652+
642653
//===----------------------------------------------------------------------===//
643654
// Test that ops without type deduction can be created with type builders.
644655
//===----------------------------------------------------------------------===//
645656

657+
// CHECK-LABEL: @explicitReturnTypeTest
646658
func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
647659
%0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8
648660
// CHECK: "test.op_x"(%arg0) : (i64) -> i32
649661
// CHECK: "test.op_x"(%0) : (i32) -> i8
650662
return %0 : i8
651663
}
652664

665+
// CHECK-LABEL: @returnTypeBuilderTest
653666
func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
654667
%0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8
655668
// CHECK: "test.op_x"(%arg0) : (i1) -> i1
656669
// CHECK: "test.op_x"(%0) : (i1) -> i8
657670
return %0 : i8
658671
}
659672

673+
// CHECK-LABEL: @multipleReturnTypeBuildTest
660674
func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
661675
%0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1
662676
// CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32)
@@ -666,13 +680,15 @@ func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
666680
return %0 : i1
667681
}
668682

683+
// CHECK-LABEL: @copyValueType
669684
func.func @copyValueType(%arg0 : i8) -> i32 {
670685
%0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32
671686
// CHECK: "test.op_x"(%arg0) : (i8) -> i8
672687
// CHECK: "test.op_x"(%0) : (i8) -> i32
673688
return %0 : i32
674689
}
675690

691+
// CHECK-LABEL: @multipleReturnTypeDifferent
676692
func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
677693
%0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64
678694
// CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64)
@@ -684,6 +700,7 @@ func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
684700
// Test that multiple trailing directives can be mixed in patterns.
685701
//===----------------------------------------------------------------------===//
686702

703+
// CHECK-LABEL: @returnTypeAndLocation
687704
func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
688705
%0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1
689706
// CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1")
@@ -696,6 +713,7 @@ func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
696713
// Test that patterns can create ConstantStrAttr
697714
//===----------------------------------------------------------------------===//
698715

716+
// CHECK-LABEL: @testConstantStrAttr
699717
func.func @testConstantStrAttr() -> () {
700718
// CHECK: test.has_str_value {value = "foo"}
701719
test.no_str_value {value = "bar"}
@@ -706,6 +724,7 @@ func.func @testConstantStrAttr() -> () {
706724
// Test that patterns with variadics propagate sizes
707725
//===----------------------------------------------------------------------===//
708726

727+
// CHECK-LABEL: @testVariadic
709728
func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64,
710729
%crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () {
711730
// CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 4, 2>}> : (i64, f32, f32, f32, f32, i32, i32) -> ()

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
658658
if (isa<NamedTypeConstraint *>(opArg)) {
659659
auto operandName =
660660
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
661-
emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
661+
emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
662662
/*operandMatcher=*/tree.getArgAsLeaf(i),
663663
/*argName=*/tree.getArgName(i), opArgIdx,
664664
/*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
680680
int argIndex,
681681
std::optional<int> variadicSubIndex) {
682682
Operator &op = tree.getDialectOp(opMap);
683-
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
683+
NamedTypeConstraint operand = op.getOperand(operandIndex);
684684

685685
// If a constraint is specified, we need to generate C++ statements to
686686
// check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
693693
// Only need to verify if the matcher's type is different from the one
694694
// of op definition.
695695
Constraint constraint = operandMatcher.getAsConstraint();
696-
if (operand->constraint != constraint) {
697-
if (operand->isVariableLength()) {
696+
if (operand.constraint != constraint) {
697+
if (operand.isVariableLength()) {
698698
auto error = formatv(
699699
"further constrain op {0}'s variadic operand #{1} unsupported now",
700700
op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
706706
verifier, opName, self.str(),
707707
formatv(
708708
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
709-
operand - op.operand_begin(), op.getOperationName(),
709+
operandIndex, op.getOperationName(),
710710
escapeString(constraint.getSummary()))
711711
.str());
712712
}
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
715715
// Capture the value
716716
// `$_` is a special symbol to ignore op argument matching.
717717
if (!argName.empty() && argName != "_") {
718-
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
718+
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
719719
variadicSubIndex);
720720
if (res == symbolInfoMap.end())
721721
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
821821
StringRef variadicTreeName = variadicArgTree.getSymbol();
822822
if (!variadicTreeName.empty()) {
823823
auto res =
824-
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
824+
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
825825
/*variadicSubIndex=*/std::nullopt);
826826
if (res == symbolInfoMap.end())
827827
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));

0 commit comments

Comments
 (0)