Skip to content

Commit 78b3a00

Browse files
authored
[mlir] int-range-optmizations: Fix referencing of deleted ops (#91807)
The pass runs a `DataFlowSolver` and collects state information on the input IR. Then, the rewrite driver and folding is applied. During pattern application and folding it can happen that an Op from the input IR is deleted and a new Op is created at the same address. When the newly created Ops is looked up in the `DataFlowSolver` state memory, the state of the original Op is returned. This patch adds a method to `DataFlowSolver` which removes all state related to a `ProgramPoint`. It also adds a listener to the Pass which clears the state information of deleted Ops from the `DataFlowSolver`. Fix #81228
1 parent 502e77d commit 78b3a00

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

mlir/include/mlir/Analysis/DataFlowFramework.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,17 @@ class DataFlowSolver {
242242
return static_cast<const StateT *>(it->second.get());
243243
}
244244

245+
/// Erase any analysis state associated with the given program point.
246+
template <typename PointT>
247+
void eraseState(PointT point) {
248+
ProgramPoint pp(point);
249+
250+
for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
251+
if (it->first.first == pp)
252+
analysisStates.erase(it);
253+
}
254+
}
255+
245256
/// Get a uniqued program point instance. If one is not present, it is
246257
/// created with the provided arguments.
247258
template <typename PointT, typename... Args>

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,24 @@ static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
102102
}
103103

104104
namespace {
105+
/// This class listens on IR transformations performed during a pass relying on
106+
/// information from a `DataflowSolver`. It erases state associated with the
107+
/// erased operation and its results from the `DataFlowSolver` so that Patterns
108+
/// do not accidentally query old state information for newly created Ops.
109+
class DataFlowListener : public RewriterBase::Listener {
110+
public:
111+
DataFlowListener(DataFlowSolver &s) : s(s) {}
112+
113+
protected:
114+
void notifyOperationErased(Operation *op) override {
115+
s.eraseState(op);
116+
for (Value res : op->getResults())
117+
s.eraseState(res);
118+
}
119+
120+
DataFlowSolver &s;
121+
};
122+
105123
struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
106124

107125
ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
@@ -167,10 +185,15 @@ struct IntRangeOptimizationsPass
167185
if (failed(solver.initializeAndRun(op)))
168186
return signalPassFailure();
169187

188+
DataFlowListener listener(solver);
189+
170190
RewritePatternSet patterns(ctx);
171191
populateIntRangeOptimizationsPatterns(patterns, solver);
172192

173-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
193+
GreedyRewriteConfig config;
194+
config.listener = &listener;
195+
196+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
174197
signalPassFailure();
175198
}
176199
};

0 commit comments

Comments
 (0)