8
8
9
9
#include < utility>
10
10
11
+ #include " mlir/Analysis/DataFlowFramework.h"
11
12
#include " mlir/Dialect/Arith/Transforms/Passes.h"
12
13
13
14
#include " mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
14
15
#include " mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
15
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
18
+ #include " mlir/IR/Matchers.h"
19
+ #include " mlir/IR/PatternMatch.h"
20
+ #include " mlir/Interfaces/SideEffectInterfaces.h"
21
+ #include " mlir/Transforms/FoldUtils.h"
16
22
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
17
23
18
24
namespace mlir ::arith {
@@ -24,88 +30,50 @@ using namespace mlir;
24
30
using namespace mlir ::arith;
25
31
using namespace mlir ::dataflow;
26
32
27
- // / Returns true if 2 integer ranges have intersection.
28
- static bool intersects (const ConstantIntRanges &lhs,
29
- const ConstantIntRanges &rhs) {
30
- return !((lhs.smax ().slt (rhs.smin ()) || lhs.smin ().sgt (rhs.smax ())) &&
31
- (lhs.umax ().ult (rhs.umin ()) || lhs.umin ().ugt (rhs.umax ())));
33
+ static std::optional<APInt> getMaybeConstantValue (DataFlowSolver &solver,
34
+ Value value) {
35
+ auto *maybeInferredRange =
36
+ solver.lookupState <IntegerValueRangeLattice>(value);
37
+ if (!maybeInferredRange || maybeInferredRange->getValue ().isUninitialized ())
38
+ return std::nullopt;
39
+ const ConstantIntRanges &inferredRange =
40
+ maybeInferredRange->getValue ().getValue ();
41
+ return inferredRange.getConstantValue ();
32
42
}
33
43
34
- static FailureOr<bool > handleEq (ConstantIntRanges lhs, ConstantIntRanges rhs) {
35
- if (!intersects (lhs, rhs))
36
- return false ;
37
-
38
- return failure ();
39
- }
40
-
41
- static FailureOr<bool > handleNe (ConstantIntRanges lhs, ConstantIntRanges rhs) {
42
- if (!intersects (lhs, rhs))
43
- return true ;
44
-
45
- return failure ();
46
- }
47
-
48
- static FailureOr<bool > handleSlt (ConstantIntRanges lhs, ConstantIntRanges rhs) {
49
- if (lhs.smax ().slt (rhs.smin ()))
50
- return true ;
51
-
52
- if (lhs.smin ().sge (rhs.smax ()))
53
- return false ;
54
-
55
- return failure ();
56
- }
57
-
58
- static FailureOr<bool > handleSle (ConstantIntRanges lhs, ConstantIntRanges rhs) {
59
- if (lhs.smax ().sle (rhs.smin ()))
60
- return true ;
61
-
62
- if (lhs.smin ().sgt (rhs.smax ()))
63
- return false ;
64
-
65
- return failure ();
66
- }
67
-
68
- static FailureOr<bool > handleSgt (ConstantIntRanges lhs, ConstantIntRanges rhs) {
69
- return handleSlt (std::move (rhs), std::move (lhs));
70
- }
71
-
72
- static FailureOr<bool > handleSge (ConstantIntRanges lhs, ConstantIntRanges rhs) {
73
- return handleSle (std::move (rhs), std::move (lhs));
74
- }
75
-
76
- static FailureOr<bool > handleUlt (ConstantIntRanges lhs, ConstantIntRanges rhs) {
77
- if (lhs.umax ().ult (rhs.umin ()))
78
- return true ;
79
-
80
- if (lhs.umin ().uge (rhs.umax ()))
81
- return false ;
82
-
83
- return failure ();
84
- }
85
-
86
- static FailureOr<bool > handleUle (ConstantIntRanges lhs, ConstantIntRanges rhs) {
87
- if (lhs.umax ().ule (rhs.umin ()))
88
- return true ;
89
-
90
- if (lhs.umin ().ugt (rhs.umax ()))
91
- return false ;
92
-
93
- return failure ();
94
- }
95
-
96
- static FailureOr<bool > handleUgt (ConstantIntRanges lhs, ConstantIntRanges rhs) {
97
- return handleUlt (std::move (rhs), std::move (lhs));
98
- }
99
-
100
- static FailureOr<bool > handleUge (ConstantIntRanges lhs, ConstantIntRanges rhs) {
101
- return handleUle (std::move (rhs), std::move (lhs));
44
+ // / Patterned after SCCP
45
+ static LogicalResult maybeReplaceWithConstant (DataFlowSolver &solver,
46
+ PatternRewriter &rewriter,
47
+ Value value) {
48
+ if (value.use_empty ())
49
+ return failure ();
50
+ std::optional<APInt> maybeConstValue = getMaybeConstantValue (solver, value);
51
+ if (!maybeConstValue.has_value ())
52
+ return failure ();
53
+
54
+ Operation *maybeDefiningOp = value.getDefiningOp ();
55
+ Dialect *valueDialect =
56
+ maybeDefiningOp ? maybeDefiningOp->getDialect ()
57
+ : value.getParentRegion ()->getParentOp ()->getDialect ();
58
+ Attribute constAttr =
59
+ rewriter.getIntegerAttr (value.getType (), *maybeConstValue);
60
+ Operation *constOp = valueDialect->materializeConstant (
61
+ rewriter, constAttr, value.getType (), value.getLoc ());
62
+ // Fall back to arith.constant if the dialect materializer doesn't know what
63
+ // to do with an integer constant.
64
+ if (!constOp)
65
+ constOp = rewriter.getContext ()
66
+ ->getLoadedDialect <ArithDialect>()
67
+ ->materializeConstant (rewriter, constAttr, value.getType (),
68
+ value.getLoc ());
69
+ if (!constOp)
70
+ return failure ();
71
+
72
+ rewriter.replaceAllUsesWith (value, constOp->getResult (0 ));
73
+ return success ();
102
74
}
103
75
104
76
namespace {
105
- // / This class listens on IR transformations performed during a pass relying on
106
- // / information from a `DataflowSolver`. It erases state associated with the
107
- // / erased operation and its results from the `DataFlowSolver` so that Patterns
108
- // / do not accidentally query old state information for newly created Ops.
109
77
class DataFlowListener : public RewriterBase ::Listener {
110
78
public:
111
79
DataFlowListener (DataFlowSolver &s) : s(s) {}
@@ -120,52 +88,95 @@ class DataFlowListener : public RewriterBase::Listener {
120
88
DataFlowSolver &s;
121
89
};
122
90
123
- struct ConvertCmpOp : public OpRewritePattern <arith::CmpIOp> {
91
+ // / Rewrite any results of `op` that were inferred to be constant integers to
92
+ // / and replace their uses with that constant. Return success() if all results
93
+ // / where thus replaced and the operation is erased. Also replace any block
94
+ // / arguments with their constant values.
95
+ struct MaterializeKnownConstantValues : public RewritePattern {
96
+ MaterializeKnownConstantValues (MLIRContext *context, DataFlowSolver &s)
97
+ : RewritePattern(Pattern::MatchAnyOpTypeTag(), /* benefit=*/ 1 , context),
98
+ solver (s) {}
99
+
100
+ LogicalResult match (Operation *op) const override {
101
+ if (matchPattern (op, m_Constant ()))
102
+ return failure ();
124
103
125
- ConvertCmpOp (MLIRContext *context, DataFlowSolver &s)
126
- : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
104
+ auto needsReplacing = [&](Value v) {
105
+ return getMaybeConstantValue (solver, v).has_value () && !v.use_empty ();
106
+ };
107
+ bool hasConstantResults = llvm::any_of (op->getResults (), needsReplacing);
108
+ if (op->getNumRegions () == 0 )
109
+ return success (hasConstantResults);
110
+ bool hasConstantRegionArgs = false ;
111
+ for (Region ®ion : op->getRegions ()) {
112
+ for (Block &block : region.getBlocks ()) {
113
+ hasConstantRegionArgs |=
114
+ llvm::any_of (block.getArguments (), needsReplacing);
115
+ }
116
+ }
117
+ return success (hasConstantResults || hasConstantRegionArgs);
118
+ }
127
119
128
- LogicalResult matchAndRewrite (arith::CmpIOp op,
120
+ void rewrite (Operation *op, PatternRewriter &rewriter) const override {
121
+ bool replacedAll = (op->getNumResults () != 0 );
122
+ for (Value v : op->getResults ())
123
+ replacedAll &=
124
+ (succeeded (maybeReplaceWithConstant (solver, rewriter, v)) ||
125
+ v.use_empty ());
126
+ if (replacedAll && isOpTriviallyDead (op)) {
127
+ rewriter.eraseOp (op);
128
+ return ;
129
+ }
130
+
131
+ PatternRewriter::InsertionGuard guard (rewriter);
132
+ for (Region ®ion : op->getRegions ()) {
133
+ for (Block &block : region.getBlocks ()) {
134
+ rewriter.setInsertionPointToStart (&block);
135
+ for (BlockArgument &arg : block.getArguments ()) {
136
+ (void )maybeReplaceWithConstant (solver, rewriter, arg);
137
+ }
138
+ }
139
+ }
140
+ }
141
+
142
+ private:
143
+ DataFlowSolver &solver;
144
+ };
145
+
146
+ template <typename RemOp>
147
+ struct DeleteTrivialRem : public OpRewritePattern <RemOp> {
148
+ DeleteTrivialRem (MLIRContext *context, DataFlowSolver &s)
149
+ : OpRewritePattern<RemOp>(context), solver(s) {}
150
+
151
+ LogicalResult matchAndRewrite (RemOp op,
129
152
PatternRewriter &rewriter) const override {
130
- auto *lhsResult =
131
- solver.lookupState <dataflow::IntegerValueRangeLattice>(op.getLhs ());
132
- if (!lhsResult || lhsResult->getValue ().isUninitialized ())
153
+ Value lhs = op.getOperand (0 );
154
+ Value rhs = op.getOperand (1 );
155
+ auto maybeModulus = getConstantIntValue (rhs);
156
+ if (!maybeModulus.has_value ())
133
157
return failure ();
134
-
135
- auto *rhsResult =
136
- solver.lookupState <dataflow::IntegerValueRangeLattice>(op.getRhs ());
137
- if (!rhsResult || rhsResult->getValue ().isUninitialized ())
158
+ int64_t modulus = *maybeModulus;
159
+ if (modulus <= 0 )
138
160
return failure ();
139
-
140
- using HandlerFunc =
141
- FailureOr<bool > (*)(ConstantIntRanges, ConstantIntRanges);
142
- std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate () + 1 >
143
- handlers{};
144
- using Pred = arith::CmpIPredicate;
145
- handlers[static_cast <size_t >(Pred::eq)] = &handleEq;
146
- handlers[static_cast <size_t >(Pred::ne)] = &handleNe;
147
- handlers[static_cast <size_t >(Pred::slt)] = &handleSlt;
148
- handlers[static_cast <size_t >(Pred::sle)] = &handleSle;
149
- handlers[static_cast <size_t >(Pred::sgt)] = &handleSgt;
150
- handlers[static_cast <size_t >(Pred::sge)] = &handleSge;
151
- handlers[static_cast <size_t >(Pred::ult)] = &handleUlt;
152
- handlers[static_cast <size_t >(Pred::ule)] = &handleUle;
153
- handlers[static_cast <size_t >(Pred::ugt)] = &handleUgt;
154
- handlers[static_cast <size_t >(Pred::uge)] = &handleUge;
155
-
156
- HandlerFunc handler = handlers[static_cast <size_t >(op.getPredicate ())];
157
- if (!handler)
161
+ auto *maybeLhsRange = solver.lookupState <IntegerValueRangeLattice>(lhs);
162
+ if (!maybeLhsRange || maybeLhsRange->getValue ().isUninitialized ())
158
163
return failure ();
159
-
160
- ConstantIntRanges lhsValue = lhsResult->getValue ().getValue ();
161
- ConstantIntRanges rhsValue = rhsResult->getValue ().getValue ();
162
- FailureOr<bool > result = handler (lhsValue, rhsValue);
163
-
164
- if (failed (result))
164
+ const ConstantIntRanges &lhsRange = maybeLhsRange->getValue ().getValue ();
165
+ const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin () : lhsRange.smin ();
166
+ const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax () : lhsRange.smax ();
167
+ // The minima and maxima here are given as closed ranges, we must be
168
+ // strictly less than the modulus.
169
+ if (min.isNegative () || min.uge (modulus))
170
+ return failure ();
171
+ if (max.isNegative () || max.uge (modulus))
172
+ return failure ();
173
+ if (!min.ule (max))
165
174
return failure ();
166
175
167
- rewriter.replaceOpWithNewOp <arith::ConstantIntOp>(
168
- op, static_cast <int64_t >(*result), /* width*/ 1 );
176
+ // With all those conditions out of the way, we know thas this invocation of
177
+ // a remainder is a noop because the input is strictly within the range
178
+ // [0, modulus), so get rid of it.
179
+ rewriter.replaceOp (op, ValueRange{lhs});
169
180
return success ();
170
181
}
171
182
@@ -201,7 +212,8 @@ struct IntRangeOptimizationsPass
201
212
202
213
void mlir::arith::populateIntRangeOptimizationsPatterns (
203
214
RewritePatternSet &patterns, DataFlowSolver &solver) {
204
- patterns.add <ConvertCmpOp>(patterns.getContext (), solver);
215
+ patterns.add <MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
216
+ DeleteTrivialRem<RemUIOp>>(patterns.getContext (), solver);
205
217
}
206
218
207
219
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass () {
0 commit comments