Skip to content

Commit 0ea96e4

Browse files
Track erased ops separately (llvm#83051)
llvm#83023 fixed a performance regression related to "ignored" ops. This broke some downstream projects that access ops after they were replaced (an API violation). This change restores the original behavior before llvm#83023 (but without the performance regression), to give downstream users more time to fix their code.
1 parent c2042c3 commit 0ea96e4

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ namespace {
153153
/// This is useful when saving and undoing a set of rewrites.
154154
struct RewriterState {
155155
RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
156-
unsigned numErased)
156+
unsigned numErased, unsigned numReplacedOps)
157157
: numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
158-
numErased(numErased) {}
158+
numErased(numErased), numReplacedOps(numReplacedOps) {}
159159

160160
/// The current number of rewrites performed.
161161
unsigned numRewrites;
@@ -165,6 +165,9 @@ struct RewriterState {
165165

166166
/// The current number of erased operations/blocks.
167167
unsigned numErased;
168+
169+
/// The current number of replaced ops that are scheduled for erasure.
170+
unsigned numReplacedOps;
168171
};
169172

170173
//===----------------------------------------------------------------------===//
@@ -954,6 +957,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
954957
/// operation was ignored.
955958
SetVector<Operation *> ignoredOps;
956959

960+
// A set of operations that were erased.
961+
SetVector<Operation *> replacedOps;
962+
957963
/// The current type converter, or nullptr if no type converter is currently
958964
/// active.
959965
const TypeConverter *currentTypeConverter = nullptr;
@@ -1152,7 +1158,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11521158

11531159
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
11541160
return RewriterState(rewrites.size(), ignoredOps.size(),
1155-
eraseRewriter.erased.size());
1161+
eraseRewriter.erased.size(), replacedOps.size());
11561162
}
11571163

11581164
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1165,6 +1171,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
11651171

11661172
while (eraseRewriter.erased.size() != state.numErased)
11671173
eraseRewriter.erased.pop_back();
1174+
1175+
while (replacedOps.size() != state.numReplacedOps)
1176+
replacedOps.pop_back();
11681177
}
11691178

11701179
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1228,9 +1237,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12281237
return success();
12291238
}
12301239

1240+
// TODO: This function is a misnomer. It does not actually check if `op` is in
1241+
// `ignoredOps`.
12311242
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
12321243
// Check to see if this operation or the parent operation is ignored.
1233-
return ignoredOps.count(op->getParentOp()) || ignoredOps.count(op);
1244+
return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
12341245
}
12351246

12361247
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1479,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14791490
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
14801491
ValueRange newValues) {
14811492
assert(newValues.size() == op->getNumResults());
1482-
assert(!ignoredOps.contains(op) && "operation was already replaced");
1493+
assert(!replacedOps.contains(op) && "operation was already replaced");
14831494

14841495
// Track if any of the results changed, e.g. erased and replaced with null.
14851496
bool resultChanged = false;
@@ -1500,7 +1511,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
15001511

15011512
// Mark this operation as recursively ignored so that we don't need to
15021513
// convert any nested operations.
1503-
ignoredOps.insert(op);
1514+
replacedOps.insert(op);
15041515
markNestedOpsIgnored(op);
15051516
}
15061517

0 commit comments

Comments
 (0)