Skip to content

Commit 09316e0

Browse files
committed
[MLIR][DRR] Fix inconsistent operand and arg index usage
1 parent 7b8bc1b commit 09316e0

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
18721872
let results = (outs I32:$output);
18731873
}
18741874

1875+
def TestEitherOpC : TEST_Op<"either_op_c"> {
1876+
let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
1877+
let results = (outs I32:$output);
1878+
}
1879+
18751880
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
18761881
(TestEitherOpB $arg2, $x)>;
18771882

@@ -1883,6 +1888,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
18831888
$x),
18841889
(TestEitherOpB $arg2, $x)>;
18851890

1891+
def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
1892+
(TestEitherOpB $arg1, $arg2)>;
1893+
18861894
def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
18871895
let arguments = (ins I32:$arg0);
18881896
let results = (outs I32:$output);

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
658658
if (isa<NamedTypeConstraint *>(opArg)) {
659659
auto operandName =
660660
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
661-
emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
661+
emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
662662
/*operandMatcher=*/tree.getArgAsLeaf(i),
663663
/*argName=*/tree.getArgName(i), opArgIdx,
664664
/*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
680680
int argIndex,
681681
std::optional<int> variadicSubIndex) {
682682
Operator &op = tree.getDialectOp(opMap);
683-
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
683+
NamedTypeConstraint operand = op.getOperand(operandIndex);
684684

685685
// If a constraint is specified, we need to generate C++ statements to
686686
// check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
693693
// Only need to verify if the matcher's type is different from the one
694694
// of op definition.
695695
Constraint constraint = operandMatcher.getAsConstraint();
696-
if (operand->constraint != constraint) {
697-
if (operand->isVariableLength()) {
696+
if (operand.constraint != constraint) {
697+
if (operand.isVariableLength()) {
698698
auto error = formatv(
699699
"further constrain op {0}'s variadic operand #{1} unsupported now",
700700
op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
706706
verifier, opName, self.str(),
707707
formatv(
708708
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
709-
operand - op.operand_begin(), op.getOperationName(),
709+
operandIndex, op.getOperationName(),
710710
escapeString(constraint.getSummary()))
711711
.str());
712712
}
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
715715
// Capture the value
716716
// `$_` is a special symbol to ignore op argument matching.
717717
if (!argName.empty() && argName != "_") {
718-
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
718+
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
719719
variadicSubIndex);
720720
if (res == symbolInfoMap.end())
721721
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
821821
StringRef variadicTreeName = variadicArgTree.getSymbol();
822822
if (!variadicTreeName.empty()) {
823823
auto res =
824-
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
824+
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
825825
/*variadicSubIndex=*/std::nullopt);
826826
if (res == symbolInfoMap.end())
827827
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));

0 commit comments

Comments
 (0)