Skip to content

[mlir][Arith] Generalize and improve -int-range-optimizations #94712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented Jun 7, 2024

When the integer range analysis was first develop, a pass that did integer range-based constant folding was developed and used as a test pass. There was an intent to add such a folding to SCCP, but that hasn't happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's transformations. The cmpi simplification in that pass is a strict subset of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizaitons, subsuming its previous contents. It also adds an optimization from rocMLIR where rem{s,u}i operations that are noops are replaced by their left operands.

When the integer range analysis was first develop, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such a folding to SCCP, but that
hasn't happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict
subset of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizaitons,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are noops are replaced by
their left operands.
@llvmbot
Copy link
Member

llvmbot commented Jun 7, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-index

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

When the integer range analysis was first develop, a pass that did integer range-based constant folding was developed and used as a test pass. There was an intent to add such a folding to SCCP, but that hasn't happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's transformations. The cmpi simplification in that pass is a strict subset of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizaitons, subsuming its previous contents. It also adds an optimization from rocMLIR where rem{s,u}i operations that are noops are replaced by their left operands.


Patch is 24.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94712.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (-4)
  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+7-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+132-155)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+1-1)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+36)
  • (modified) mlir/test/Dialect/GPU/int-range-interface.mlir (+1-1)
  • (modified) mlir/test/Dialect/Index/int-range-inference.mlir (+1-1)
  • (modified) mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir (+3-7)
  • (modified) mlir/test/lib/Transforms/CMakeLists.txt (-1)
  • (removed) mlir/test/lib/Transforms/TestIntRangeInference.cpp (-125)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (-2)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 9dc262cc72ed0..b8a7d0c78d323 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -64,10 +64,6 @@ void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 /// equivalent.
 std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
 
-/// Add patterns for int range based optimizations.
-void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
-                                           DataFlowSolver &solver);
-
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..1517f71f1a7c9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   let summary = "Do optimizations based on integer range analysis";
   let description = [{
     This pass runs integer range analysis and apllies optimizations based on its
-    results. e.g. replace arith.cmpi with const if it can be inferred from
-    args ranges.
+    results. It replaces operations with known-constant results with said constants,
+    rewrites `(0 <= %x < D) mod D` to `%x`.
   }];
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
 }
 
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 2473169962b95..e991d0fbe7410 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -13,7 +13,8 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 namespace mlir::arith {
 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
@@ -24,155 +25,145 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::dataflow;
 
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
-                       const ConstantIntRanges &rhs) {
-  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
-           (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+/// Patterned after SCCP
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
+                                         RewriterBase &rewriter,
+                                         OperationFolder &folder, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.has_value())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  Attribute constAttr =
+      rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(
+      rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType());
+  // Fall back to arith.constant if the dialect materializer doesn't know what
+  // to do with an integer constant.
+  if (!constant)
+    constant = folder.getOrCreateConstant(
+        rewriter.getInsertionBlock(),
+        rewriter.getContext()->getLoadedDialect<ArithDialect>(), constAttr,
+        value.getType());
+  if (!constant)
+    return failure();
+
+  rewriter.replaceAllUsesWith(value, constant);
+  return success();
 }
 
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return true;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().slt(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sge(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().sle(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sgt(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ult(rhs.umin()))
-    return true;
-
-  if (lhs.umin().uge(rhs.umax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ule(rhs.umin()))
-    return true;
-
-  if (lhs.umin().ugt(rhs.umax()))
-    return false;
-
+/// Rewrite any results of `op` that were inferred to be constant integers to
+/// and replace their uses with that constant. Return success() if all results
+/// where thus replaced and the operation is erased.
+static LogicalResult foldResultsToConstants(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            OperationFolder &folder,
+                                            Operation &op) {
+  bool replacedAll = op.getNumResults() != 0;
+  for (Value res : op.getResults())
+    replacedAll &=
+        succeeded(replaceWithConstant(solver, rewriter, folder, res));
+
+  // If all of the results of the operation were replaced, try to erase
+  // the operation completely.
+  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+    assert(op.use_empty() && "expected all uses to be replaced");
+    rewriter.eraseOp(&op);
+    return success();
+  }
   return failure();
 }
 
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUlt(std::move(rhs), std::move(lhs));
+/// This function hasn't come from anywhere and is relying on the overall
+/// tests of the integer range inference implementation for its correctness.
+static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            Operation &op) {
+  if (!isa<RemSIOp, RemUIOp>(op))
+    return failure();
+  Value lhs = op.getOperand(0);
+  Value rhs = op.getOperand(1);
+  auto rhsConstVal = rhs.getDefiningOp<arith::ConstantIntOp>();
+  if (!rhsConstVal)
+    return failure();
+  int64_t modulus = rhsConstVal.value();
+  if (modulus <= 0)
+    return failure();
+  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+  const APInt &min = llvm::isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
+  const APInt &max = llvm::isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
+  // The minima and maxima here are given as closed ranges, we must be strictly
+  // less than the modulus.
+  if (min.isNegative() || min.uge(modulus))
+    return failure();
+  if (max.isNegative() || max.uge(modulus))
+    return failure();
+  if (!min.ule(max))
+    return failure();
+
+  // With all those conditions out of the way, we know thas this invocation of
+  // a remainder is a noop because the input is strictly within the range
+  // [0, modulus), so get rid of it.
+  rewriter.replaceOp(&op, ValueRange{lhs});
+  return success();
 }
 
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUle(std::move(rhs), std::move(lhs));
+static void doRewrites(DataFlowSolver &solver, MLIRContext *context,
+                       MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        worklist.push_back(&block);
+  };
+
+  IRRewriter rewriter(context);
+  OperationFolder folder(context, rewriter.getListener());
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      if (matchPattern(&op, m_Constant())) {
+        if (auto arithConstant = dyn_cast<ConstantOp>(op))
+          folder.insertKnownConstant(&op, arithConstant.getValue());
+        else
+          folder.insertKnownConstant(&op);
+        continue;
+      }
+      rewriter.setInsertionPoint(&op);
+
+      // Try rewrites. Success means that the underlying operation was erased.
+      if (succeeded(foldResultsToConstants(solver, rewriter, folder, op)))
+        continue;
+      if (isa<RemSIOp, RemUIOp>(op) &&
+          succeeded(deleteTrivialRemainder(solver, rewriter, op)))
+        continue;
+      // Add any the regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+
+    // Replace any block arguments with constants.
+    rewriter.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(solver, rewriter, folder, arg);
+  }
 }
 
 namespace {
-/// This class listens on IR transformations performed during a pass relying on
-/// information from a `DataflowSolver`. It erases state associated with the
-/// erased operation and its results from the `DataFlowSolver` so that Patterns
-/// do not accidentally query old state information for newly created Ops.
-class DataFlowListener : public RewriterBase::Listener {
-public:
-  DataFlowListener(DataFlowSolver &s) : s(s) {}
-
-protected:
-  void notifyOperationErased(Operation *op) override {
-    s.eraseState(op);
-    for (Value res : op->getResults())
-      s.eraseState(res);
-  }
-
-  DataFlowSolver &s;
-};
-
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
-
-  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
-      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
-
-  LogicalResult matchAndRewrite(arith::CmpIOp op,
-                                PatternRewriter &rewriter) const override {
-    auto *lhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
-    if (!lhsResult || lhsResult->getValue().isUninitialized())
-      return failure();
-
-    auto *rhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
-    if (!rhsResult || rhsResult->getValue().isUninitialized())
-      return failure();
-
-    using HandlerFunc =
-        FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
-    std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
-        handlers{};
-    using Pred = arith::CmpIPredicate;
-    handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
-    handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
-    handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
-    handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
-    handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
-    handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
-    handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
-    handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
-    handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
-    handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
-    HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
-    if (!handler)
-      return failure();
-
-    ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
-    ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
-    FailureOr<bool> result = handler(lhsValue, rhsValue);
-
-    if (failed(result))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
-        op, static_cast<int64_t>(*result), /*width*/ 1);
-    return success();
-  }
-
-private:
-  DataFlowSolver &solver;
-};
-
 struct IntRangeOptimizationsPass
     : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -185,25 +176,11 @@ struct IntRangeOptimizationsPass
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
 
-    DataFlowListener listener(solver);
-
-    RewritePatternSet patterns(ctx);
-    populateIntRangeOptimizationsPatterns(patterns, solver);
-
-    GreedyRewriteConfig config;
-    config.listener = &listener;
-
-    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-      signalPassFailure();
+    doRewrites(solver, ctx, op->getRegions());
   }
 };
 } // namespace
 
-void mlir::arith::populateIntRangeOptimizationsPatterns(
-    RewritePatternSet &patterns, DataFlowSolver &solver) {
-  patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
-}
-
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 60f0ab41afa48..e00b7692fe396 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @add_min_max
 // CHECK: %[[c3:.*]] = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index dd62a481a1246..ea5969a100258 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
   return %1: i8
 }
 
+// -----
+
+// CHECK-LABEL: func @trivial_rem
+// CHECK: [[val:%.+]] = test.with_bounds
+// CHECK: return [[val]]
+func.func @trivial_rem() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @non_const_rhs
+// CHECK: [[mod:%.+]] = arith.remui
+// CHECK: return [[mod]]
+func.func @non_const_rhs() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
+  %rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
+  %mod = arith.remui %val, %rhs : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @wraps
+// CHECK: [[mod:%.+]] = arith.remsi
+// CHECK: return [[mod]]
+func.func @wraps() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 980f7e5873e0c..a0917a2fdf110 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @launch_func
 func.func @launch_func(%arg0 : index) {
diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
index 2784d5fd5cf70..951624d573a64 100644
--- a/mlir/test/Dialect/Index/int-range-inference.mlir
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // Most operations are covered by the `arith` tests, which use the same code
 // Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 2106eeefdca4d..1ec3441b1fde8 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
 
 // CHECK-LABEL: func @constant
 // CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop_false()
 func.func @propagate_across_while_loop_false() -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     %false = arith.constant false
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%false) %0 : index
   } do {
   ^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop
 func.func @propagate_across_while_loop(%arg0 : i1) -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%arg0) %0 : index
   } do {
   ^bb0(%i1: index):
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..66b1faf78e2d8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
-  TestIntRangeInference.cpp
   TestMakeIsolatedFromAbove.cpp
   ${MLIRTestTransformsPDLSrc}
 
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
deleted file mode 100644
index 5758f6acf2f0f..0000000000000
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// TODO: This pass is needed to test integer range inference until that
-// functionality has been integrated into SCCP.
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#includ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 7, 2024

@llvm/pr-subscribers-mlir-arith

Author: Krzysztof Drewniak (krzysz00)

Changes

When the integer range analysis was first develop, a pass that did integer range-based constant folding was developed and used as a test pass. There was an intent to add such a folding to SCCP, but that hasn't happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's transformations. The cmpi simplification in that pass is a strict subset of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizaitons, subsuming its previous contents. It also adds an optimization from rocMLIR where rem{s,u}i operations that are noops are replaced by their left operands.


Patch is 24.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94712.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (-4)
  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+7-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+132-155)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+1-1)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+36)
  • (modified) mlir/test/Dialect/GPU/int-range-interface.mlir (+1-1)
  • (modified) mlir/test/Dialect/Index/int-range-inference.mlir (+1-1)
  • (modified) mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir (+3-7)
  • (modified) mlir/test/lib/Transforms/CMakeLists.txt (-1)
  • (removed) mlir/test/lib/Transforms/TestIntRangeInference.cpp (-125)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (-2)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 9dc262cc72ed0..b8a7d0c78d323 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -64,10 +64,6 @@ void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 /// equivalent.
 std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
 
-/// Add patterns for int range based optimizations.
-void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
-                                           DataFlowSolver &solver);
-
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..1517f71f1a7c9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   let summary = "Do optimizations based on integer range analysis";
   let description = [{
     This pass runs integer range analysis and apllies optimizations based on its
-    results. e.g. replace arith.cmpi with const if it can be inferred from
-    args ranges.
+    results. It replaces operations with known-constant results with said constants,
+    rewrites `(0 <= %x < D) mod D` to `%x`.
   }];
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
 }
 
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 2473169962b95..e991d0fbe7410 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -13,7 +13,8 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 namespace mlir::arith {
 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
@@ -24,155 +25,145 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::dataflow;
 
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
-                       const ConstantIntRanges &rhs) {
-  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
-           (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+/// Patterned after SCCP
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
+                                         RewriterBase &rewriter,
+                                         OperationFolder &folder, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.has_value())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  Attribute constAttr =
+      rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(
+      rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType());
+  // Fall back to arith.constant if the dialect materializer doesn't know what
+  // to do with an integer constant.
+  if (!constant)
+    constant = folder.getOrCreateConstant(
+        rewriter.getInsertionBlock(),
+        rewriter.getContext()->getLoadedDialect<ArithDialect>(), constAttr,
+        value.getType());
+  if (!constant)
+    return failure();
+
+  rewriter.replaceAllUsesWith(value, constant);
+  return success();
 }
 
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return true;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().slt(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sge(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().sle(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sgt(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ult(rhs.umin()))
-    return true;
-
-  if (lhs.umin().uge(rhs.umax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ule(rhs.umin()))
-    return true;
-
-  if (lhs.umin().ugt(rhs.umax()))
-    return false;
-
+/// Rewrite any results of `op` that were inferred to be constant integers to
+/// and replace their uses with that constant. Return success() if all results
+/// where thus replaced and the operation is erased.
+static LogicalResult foldResultsToConstants(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            OperationFolder &folder,
+                                            Operation &op) {
+  bool replacedAll = op.getNumResults() != 0;
+  for (Value res : op.getResults())
+    replacedAll &=
+        succeeded(replaceWithConstant(solver, rewriter, folder, res));
+
+  // If all of the results of the operation were replaced, try to erase
+  // the operation completely.
+  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+    assert(op.use_empty() && "expected all uses to be replaced");
+    rewriter.eraseOp(&op);
+    return success();
+  }
   return failure();
 }
 
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUlt(std::move(rhs), std::move(lhs));
+/// This function hasn't come from anywhere and is relying on the overall
+/// tests of the integer range inference implementation for its correctness.
+static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            Operation &op) {
+  if (!isa<RemSIOp, RemUIOp>(op))
+    return failure();
+  Value lhs = op.getOperand(0);
+  Value rhs = op.getOperand(1);
+  auto rhsConstVal = rhs.getDefiningOp<arith::ConstantIntOp>();
+  if (!rhsConstVal)
+    return failure();
+  int64_t modulus = rhsConstVal.value();
+  if (modulus <= 0)
+    return failure();
+  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+  const APInt &min = llvm::isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
+  const APInt &max = llvm::isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
+  // The minima and maxima here are given as closed ranges, we must be strictly
+  // less than the modulus.
+  if (min.isNegative() || min.uge(modulus))
+    return failure();
+  if (max.isNegative() || max.uge(modulus))
+    return failure();
+  if (!min.ule(max))
+    return failure();
+
+  // With all those conditions out of the way, we know thas this invocation of
+  // a remainder is a noop because the input is strictly within the range
+  // [0, modulus), so get rid of it.
+  rewriter.replaceOp(&op, ValueRange{lhs});
+  return success();
 }
 
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUle(std::move(rhs), std::move(lhs));
+static void doRewrites(DataFlowSolver &solver, MLIRContext *context,
+                       MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        worklist.push_back(&block);
+  };
+
+  IRRewriter rewriter(context);
+  OperationFolder folder(context, rewriter.getListener());
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      if (matchPattern(&op, m_Constant())) {
+        if (auto arithConstant = dyn_cast<ConstantOp>(op))
+          folder.insertKnownConstant(&op, arithConstant.getValue());
+        else
+          folder.insertKnownConstant(&op);
+        continue;
+      }
+      rewriter.setInsertionPoint(&op);
+
+      // Try rewrites. Success means that the underlying operation was erased.
+      if (succeeded(foldResultsToConstants(solver, rewriter, folder, op)))
+        continue;
+      if (isa<RemSIOp, RemUIOp>(op) &&
+          succeeded(deleteTrivialRemainder(solver, rewriter, op)))
+        continue;
+      // Add any the regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+
+    // Replace any block arguments with constants.
+    rewriter.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(solver, rewriter, folder, arg);
+  }
 }
 
 namespace {
-/// This class listens on IR transformations performed during a pass relying on
-/// information from a `DataflowSolver`. It erases state associated with the
-/// erased operation and its results from the `DataFlowSolver` so that Patterns
-/// do not accidentally query old state information for newly created Ops.
-class DataFlowListener : public RewriterBase::Listener {
-public:
-  DataFlowListener(DataFlowSolver &s) : s(s) {}
-
-protected:
-  void notifyOperationErased(Operation *op) override {
-    s.eraseState(op);
-    for (Value res : op->getResults())
-      s.eraseState(res);
-  }
-
-  DataFlowSolver &s;
-};
-
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
-
-  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
-      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
-
-  LogicalResult matchAndRewrite(arith::CmpIOp op,
-                                PatternRewriter &rewriter) const override {
-    auto *lhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
-    if (!lhsResult || lhsResult->getValue().isUninitialized())
-      return failure();
-
-    auto *rhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
-    if (!rhsResult || rhsResult->getValue().isUninitialized())
-      return failure();
-
-    using HandlerFunc =
-        FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
-    std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
-        handlers{};
-    using Pred = arith::CmpIPredicate;
-    handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
-    handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
-    handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
-    handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
-    handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
-    handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
-    handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
-    handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
-    handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
-    handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
-    HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
-    if (!handler)
-      return failure();
-
-    ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
-    ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
-    FailureOr<bool> result = handler(lhsValue, rhsValue);
-
-    if (failed(result))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
-        op, static_cast<int64_t>(*result), /*width*/ 1);
-    return success();
-  }
-
-private:
-  DataFlowSolver &solver;
-};
-
 struct IntRangeOptimizationsPass
     : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -185,25 +176,11 @@ struct IntRangeOptimizationsPass
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
 
-    DataFlowListener listener(solver);
-
-    RewritePatternSet patterns(ctx);
-    populateIntRangeOptimizationsPatterns(patterns, solver);
-
-    GreedyRewriteConfig config;
-    config.listener = &listener;
-
-    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-      signalPassFailure();
+    doRewrites(solver, ctx, op->getRegions());
   }
 };
 } // namespace
 
-void mlir::arith::populateIntRangeOptimizationsPatterns(
-    RewritePatternSet &patterns, DataFlowSolver &solver) {
-  patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
-}
-
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 60f0ab41afa48..e00b7692fe396 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @add_min_max
 // CHECK: %[[c3:.*]] = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index dd62a481a1246..ea5969a100258 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
   return %1: i8
 }
 
+// -----
+
+// CHECK-LABEL: func @trivial_rem
+// CHECK: [[val:%.+]] = test.with_bounds
+// CHECK: return [[val]]
+func.func @trivial_rem() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @non_const_rhs
+// CHECK: [[mod:%.+]] = arith.remui
+// CHECK: return [[mod]]
+func.func @non_const_rhs() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
+  %rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
+  %mod = arith.remui %val, %rhs : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @wraps
+// CHECK: [[mod:%.+]] = arith.remsi
+// CHECK: return [[mod]]
+func.func @wraps() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 980f7e5873e0c..a0917a2fdf110 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @launch_func
 func.func @launch_func(%arg0 : index) {
diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
index 2784d5fd5cf70..951624d573a64 100644
--- a/mlir/test/Dialect/Index/int-range-inference.mlir
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // Most operations are covered by the `arith` tests, which use the same code
 // Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 2106eeefdca4d..1ec3441b1fde8 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
 
 // CHECK-LABEL: func @constant
 // CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop_false()
 func.func @propagate_across_while_loop_false() -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     %false = arith.constant false
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%false) %0 : index
   } do {
   ^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop
 func.func @propagate_across_while_loop(%arg0 : i1) -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%arg0) %0 : index
   } do {
   ^bb0(%i1: index):
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..66b1faf78e2d8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
-  TestIntRangeInference.cpp
   TestMakeIsolatedFromAbove.cpp
   ${MLIRTestTransformsPDLSrc}
 
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
deleted file mode 100644
index 5758f6acf2f0f..0000000000000
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// TODO: This pass is needed to test integer range inference until that
-// functionality has been integrated into SCCP.
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#includ...
[truncated]

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@krzysz00 krzysz00 merged commit 4722911 into llvm:main Jun 10, 2024
7 checks passed
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Jun 12, 2024
…4712)

When the integer range analysis was first develop, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such a folding to SCCP, but that hasn't
happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict subset
of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizaitons,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are noops are replaced by
their left operands.
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants