Skip to content

Commit 4722911

Browse files
authored
[mlir][Arith] Generalize and improve -int-range-optimizations (#94712)
When the integer range analysis was first develop, a pass that did integer range-based constant folding was developed and used as a test pass. There was an intent to add such a folding to SCCP, but that hasn't happened. Meanwhile, -int-range-optimizations was added to the arith dialect's transformations. The cmpi simplification in that pass is a strict subset of the constant folding that lived in -test-int-range-inference. This commit moves the former test pass into -int-range-optimizaitons, subsuming its previous contents. It also adds an optimization from rocMLIR where `rem{s,u}i` operations that are noops are replaced by their left operands.
1 parent 3e39328 commit 4722911

File tree

10 files changed

+183
-256
lines changed

10 files changed

+183
-256
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
4040
let summary = "Do optimizations based on integer range analysis";
4141
let description = [{
4242
This pass runs integer range analysis and apllies optimizations based on its
43-
results. e.g. replace arith.cmpi with const if it can be inferred from
44-
args ranges.
43+
results. It replaces operations with known-constant results with said constants,
44+
rewrites `(0 <= %x < D) mod D` to `%x`.
4545
}];
46+
// Explicitly depend on "arith" because this pass could create operations in
47+
// `arith` out of thin air in some cases.
48+
let dependentDialects = [
49+
"::mlir::arith::ArithDialect"
50+
];
4651
}
4752

4853
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 128 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@
88

99
#include <utility>
1010

11+
#include "mlir/Analysis/DataFlowFramework.h"
1112
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1213

1314
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
1415
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
18+
#include "mlir/IR/Matchers.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Interfaces/SideEffectInterfaces.h"
21+
#include "mlir/Transforms/FoldUtils.h"
1622
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1723

1824
namespace mlir::arith {
@@ -24,88 +30,50 @@ using namespace mlir;
2430
using namespace mlir::arith;
2531
using namespace mlir::dataflow;
2632

27-
/// Returns true if 2 integer ranges have intersection.
28-
static bool intersects(const ConstantIntRanges &lhs,
29-
const ConstantIntRanges &rhs) {
30-
return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
31-
(lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
33+
static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
34+
Value value) {
35+
auto *maybeInferredRange =
36+
solver.lookupState<IntegerValueRangeLattice>(value);
37+
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
38+
return std::nullopt;
39+
const ConstantIntRanges &inferredRange =
40+
maybeInferredRange->getValue().getValue();
41+
return inferredRange.getConstantValue();
3242
}
3343

34-
static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
35-
if (!intersects(lhs, rhs))
36-
return false;
37-
38-
return failure();
39-
}
40-
41-
static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
42-
if (!intersects(lhs, rhs))
43-
return true;
44-
45-
return failure();
46-
}
47-
48-
static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
49-
if (lhs.smax().slt(rhs.smin()))
50-
return true;
51-
52-
if (lhs.smin().sge(rhs.smax()))
53-
return false;
54-
55-
return failure();
56-
}
57-
58-
static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
59-
if (lhs.smax().sle(rhs.smin()))
60-
return true;
61-
62-
if (lhs.smin().sgt(rhs.smax()))
63-
return false;
64-
65-
return failure();
66-
}
67-
68-
static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
69-
return handleSlt(std::move(rhs), std::move(lhs));
70-
}
71-
72-
static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
73-
return handleSle(std::move(rhs), std::move(lhs));
74-
}
75-
76-
static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
77-
if (lhs.umax().ult(rhs.umin()))
78-
return true;
79-
80-
if (lhs.umin().uge(rhs.umax()))
81-
return false;
82-
83-
return failure();
84-
}
85-
86-
static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
87-
if (lhs.umax().ule(rhs.umin()))
88-
return true;
89-
90-
if (lhs.umin().ugt(rhs.umax()))
91-
return false;
92-
93-
return failure();
94-
}
95-
96-
static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
97-
return handleUlt(std::move(rhs), std::move(lhs));
98-
}
99-
100-
static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
101-
return handleUle(std::move(rhs), std::move(lhs));
44+
/// Patterned after SCCP
45+
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
46+
PatternRewriter &rewriter,
47+
Value value) {
48+
if (value.use_empty())
49+
return failure();
50+
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
51+
if (!maybeConstValue.has_value())
52+
return failure();
53+
54+
Operation *maybeDefiningOp = value.getDefiningOp();
55+
Dialect *valueDialect =
56+
maybeDefiningOp ? maybeDefiningOp->getDialect()
57+
: value.getParentRegion()->getParentOp()->getDialect();
58+
Attribute constAttr =
59+
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
60+
Operation *constOp = valueDialect->materializeConstant(
61+
rewriter, constAttr, value.getType(), value.getLoc());
62+
// Fall back to arith.constant if the dialect materializer doesn't know what
63+
// to do with an integer constant.
64+
if (!constOp)
65+
constOp = rewriter.getContext()
66+
->getLoadedDialect<ArithDialect>()
67+
->materializeConstant(rewriter, constAttr, value.getType(),
68+
value.getLoc());
69+
if (!constOp)
70+
return failure();
71+
72+
rewriter.replaceAllUsesWith(value, constOp->getResult(0));
73+
return success();
10274
}
10375

10476
namespace {
105-
/// This class listens on IR transformations performed during a pass relying on
106-
/// information from a `DataflowSolver`. It erases state associated with the
107-
/// erased operation and its results from the `DataFlowSolver` so that Patterns
108-
/// do not accidentally query old state information for newly created Ops.
10977
class DataFlowListener : public RewriterBase::Listener {
11078
public:
11179
DataFlowListener(DataFlowSolver &s) : s(s) {}
@@ -120,52 +88,95 @@ class DataFlowListener : public RewriterBase::Listener {
12088
DataFlowSolver &s;
12189
};
12290

123-
struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
91+
/// Rewrite any results of `op` that were inferred to be constant integers to
92+
/// and replace their uses with that constant. Return success() if all results
93+
/// where thus replaced and the operation is erased. Also replace any block
94+
/// arguments with their constant values.
95+
struct MaterializeKnownConstantValues : public RewritePattern {
96+
MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
97+
: RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
98+
solver(s) {}
99+
100+
LogicalResult match(Operation *op) const override {
101+
if (matchPattern(op, m_Constant()))
102+
return failure();
124103

125-
ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
126-
: OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
104+
auto needsReplacing = [&](Value v) {
105+
return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
106+
};
107+
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
108+
if (op->getNumRegions() == 0)
109+
return success(hasConstantResults);
110+
bool hasConstantRegionArgs = false;
111+
for (Region &region : op->getRegions()) {
112+
for (Block &block : region.getBlocks()) {
113+
hasConstantRegionArgs |=
114+
llvm::any_of(block.getArguments(), needsReplacing);
115+
}
116+
}
117+
return success(hasConstantResults || hasConstantRegionArgs);
118+
}
127119

128-
LogicalResult matchAndRewrite(arith::CmpIOp op,
120+
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
121+
bool replacedAll = (op->getNumResults() != 0);
122+
for (Value v : op->getResults())
123+
replacedAll &=
124+
(succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
125+
v.use_empty());
126+
if (replacedAll && isOpTriviallyDead(op)) {
127+
rewriter.eraseOp(op);
128+
return;
129+
}
130+
131+
PatternRewriter::InsertionGuard guard(rewriter);
132+
for (Region &region : op->getRegions()) {
133+
for (Block &block : region.getBlocks()) {
134+
rewriter.setInsertionPointToStart(&block);
135+
for (BlockArgument &arg : block.getArguments()) {
136+
(void)maybeReplaceWithConstant(solver, rewriter, arg);
137+
}
138+
}
139+
}
140+
}
141+
142+
private:
143+
DataFlowSolver &solver;
144+
};
145+
146+
template <typename RemOp>
147+
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
148+
DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
149+
: OpRewritePattern<RemOp>(context), solver(s) {}
150+
151+
LogicalResult matchAndRewrite(RemOp op,
129152
PatternRewriter &rewriter) const override {
130-
auto *lhsResult =
131-
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
132-
if (!lhsResult || lhsResult->getValue().isUninitialized())
153+
Value lhs = op.getOperand(0);
154+
Value rhs = op.getOperand(1);
155+
auto maybeModulus = getConstantIntValue(rhs);
156+
if (!maybeModulus.has_value())
133157
return failure();
134-
135-
auto *rhsResult =
136-
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
137-
if (!rhsResult || rhsResult->getValue().isUninitialized())
158+
int64_t modulus = *maybeModulus;
159+
if (modulus <= 0)
138160
return failure();
139-
140-
using HandlerFunc =
141-
FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
142-
std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
143-
handlers{};
144-
using Pred = arith::CmpIPredicate;
145-
handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
146-
handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
147-
handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
148-
handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
149-
handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
150-
handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
151-
handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
152-
handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
153-
handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
154-
handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
155-
156-
HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
157-
if (!handler)
161+
auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
162+
if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
158163
return failure();
159-
160-
ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
161-
ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
162-
FailureOr<bool> result = handler(lhsValue, rhsValue);
163-
164-
if (failed(result))
164+
const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
165+
const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
166+
const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
167+
// The minima and maxima here are given as closed ranges, we must be
168+
// strictly less than the modulus.
169+
if (min.isNegative() || min.uge(modulus))
170+
return failure();
171+
if (max.isNegative() || max.uge(modulus))
172+
return failure();
173+
if (!min.ule(max))
165174
return failure();
166175

167-
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
168-
op, static_cast<int64_t>(*result), /*width*/ 1);
176+
// With all those conditions out of the way, we know thas this invocation of
177+
// a remainder is a noop because the input is strictly within the range
178+
// [0, modulus), so get rid of it.
179+
rewriter.replaceOp(op, ValueRange{lhs});
169180
return success();
170181
}
171182

@@ -201,7 +212,8 @@ struct IntRangeOptimizationsPass
201212

202213
void mlir::arith::populateIntRangeOptimizationsPatterns(
203214
RewritePatternSet &patterns, DataFlowSolver &solver) {
204-
patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
215+
patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
216+
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
205217
}
206218

207219
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {

mlir/test/Dialect/Arith/int-range-interface.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
1+
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
22

33
// CHECK-LABEL: func @add_min_max
44
// CHECK: %[[c3:.*]] = arith.constant 3 : index

mlir/test/Dialect/Arith/int-range-opts.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
9696
return %1: i8
9797
}
9898

99+
// -----
100+
101+
// CHECK-LABEL: func @trivial_rem
102+
// CHECK: [[val:%.+]] = test.with_bounds
103+
// CHECK: return [[val]]
104+
func.func @trivial_rem() -> i8 {
105+
%c64 = arith.constant 64 : i8
106+
%val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
107+
%mod = arith.remsi %val, %c64 : i8
108+
return %mod : i8
109+
}
110+
111+
// -----
112+
113+
// CHECK-LABEL: func @non_const_rhs
114+
// CHECK: [[mod:%.+]] = arith.remui
115+
// CHECK: return [[mod]]
116+
func.func @non_const_rhs() -> i8 {
117+
%c64 = arith.constant 64 : i8
118+
%val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
119+
%rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
120+
%mod = arith.remui %val, %rhs : i8
121+
return %mod : i8
122+
}
123+
124+
// -----
125+
126+
// CHECK-LABEL: func @wraps
127+
// CHECK: [[mod:%.+]] = arith.remsi
128+
// CHECK: return [[mod]]
129+
func.func @wraps() -> i8 {
130+
%c64 = arith.constant 64 : i8
131+
%val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
132+
%mod = arith.remsi %val, %c64 : i8
133+
return %mod : i8
134+
}

mlir/test/Dialect/GPU/int-range-interface.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
22

33
// CHECK-LABEL: func @launch_func
44
func.func @launch_func(%arg0 : index) {

mlir/test/Dialect/Index/int-range-inference.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
1+
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
22

33
// Most operations are covered by the `arith` tests, which use the same code
44
// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling

mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
1+
// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
22

33
// CHECK-LABEL: func @constant
44
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
103103

104104
// CHECK-LABEL: func @propagate_across_while_loop_false()
105105
func.func @propagate_across_while_loop_false() -> index {
106-
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
107-
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
106+
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
108107
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
109108
smin = 0 : index, smax = 0 : index } : index
110109
%1 = scf.while : () -> index {
111110
%false = arith.constant false
112-
// CHECK: scf.condition(%{{.*}}) %[[C0]]
113111
scf.condition(%false) %0 : index
114112
} do {
115113
^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
122120

123121
// CHECK-LABEL: func @propagate_across_while_loop
124122
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
125-
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
126-
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
123+
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
127124
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
128125
smin = 0 : index, smax = 0 : index } : index
129126
%1 = scf.while : () -> index {
130-
// CHECK: scf.condition(%{{.*}}) %[[C0]]
131127
scf.condition(%arg0) %0 : index
132128
} do {
133129
^bb0(%i1: index):

0 commit comments

Comments
 (0)