Skip to content

Commit 2b2889b

Browse files
committed
Add index::CmpOp canonicalization.
Add canonicalization pattern for index::CmpOp Differential Revision: https://reviews.llvm.org/D157903
1 parent f8ad86c commit 2b2889b

File tree

4 files changed

+49
-4
lines changed

4 files changed

+49
-4
lines changed

mlir/include/mlir/Dialect/Index/IR/IndexOps.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
// Forward Declarations
2424
//===----------------------------------------------------------------------===//
2525

26-
namespace mlir::index {
26+
namespace mlir {
27+
class PatternRewriter;
28+
namespace index {
2729
enum class IndexCmpPredicate : uint32_t;
2830
class IndexCmpPredicateAttr;
29-
} // namespace mlir::index
31+
} // namespace index
32+
} // namespace mlir
3033

3134
//===----------------------------------------------------------------------===//
3235
// ODS-Generated Declarations

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ def Index_CmpOp : IndexOp<"cmp"> {
542542
let results = (outs I1:$result);
543543
let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict";
544544
let hasFolder = 1;
545+
let hasCanonicalizeMethod = 1;
545546
}
546547

547548
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/Builders.h"
1313
#include "mlir/IR/Matchers.h"
1414
#include "mlir/IR/OpImplementation.h"
15+
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
1617
#include "llvm/ADT/SmallString.h"
1718
#include "llvm/ADT/TypeSwitch.h"
@@ -549,6 +550,37 @@ OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
549550
return {};
550551
}
551552

553+
/// Canonicalize
554+
/// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
555+
/// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
556+
LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
557+
IntegerAttr cmpRhs;
558+
IntegerAttr cmpLhs;
559+
560+
bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
561+
cmpRhs.getValue().isZero();
562+
bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
563+
cmpLhs.getValue().isZero();
564+
if (!rhsIsZero && !lhsIsZero)
565+
return rewriter.notifyMatchFailure(op.getLoc(),
566+
"cmp is not comparing something with 0");
567+
SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
568+
: op.getRhs().getDefiningOp<index::SubOp>();
569+
if (!subOp)
570+
return rewriter.notifyMatchFailure(
571+
op.getLoc(), "non-zero operand is not a result of subtraction");
572+
573+
index::CmpOp newCmp;
574+
if (rhsIsZero)
575+
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
576+
subOp.getLhs(), subOp.getRhs());
577+
else
578+
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
579+
subOp.getRhs(), subOp.getLhs());
580+
rewriter.replaceOp(op, newCmp);
581+
return success();
582+
}
583+
552584
//===----------------------------------------------------------------------===//
553585
// ConstantOp
554586
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Index/index-canonicalize.mlir

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ func.func @xor() -> index {
473473
}
474474

475475
// CHECK-LABEL: @cmp
476-
func.func @cmp() -> (i1, i1, i1, i1) {
476+
func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
477477
%a = index.constant 0
478478
%b = index.constant -1
479479
%c = index.constant -2
@@ -484,10 +484,19 @@ func.func @cmp() -> (i1, i1, i1, i1) {
484484
%2 = index.cmp ne(%d, %a)
485485
%3 = index.cmp sgt(%b, %a)
486486

487+
%4 = index.sub %a, %arg0
488+
%5 = index.cmp sgt(%4, %a)
489+
490+
%6 = index.sub %a, %arg0
491+
%7 = index.cmp sgt(%a, %6)
492+
487493
// CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
488494
// CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
495+
// CHECK-DAG: [[IDX0:%.*]] = index.constant 0
496+
// CHECK-DAG: [[V4:%.*]] = index.cmp sgt([[IDX0]], %arg0)
497+
// CHECK-DAG: [[V5:%.*]] = index.cmp sgt(%arg0, [[IDX0]])
489498
// CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]]
490-
return %0, %1, %2, %3 : i1, i1, i1, i1
499+
return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1
491500
}
492501

493502
// CHECK-LABEL: @cmp_nofold

0 commit comments

Comments
 (0)