Skip to content

Commit 9ca70d7

Browse files
[mlir][Transforms][NFC] Turn op creation into IRRewrite (#81759)
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure). Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an `IRRewrite` that can be committed and rolled back just like any other rewrite. This commit simplifies the internal state of the dialect conversion.
1 parent ace83da commit 9ca70d7

File tree

1 file changed

+64
-38
lines changed

1 file changed

+64
-38
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,12 @@ namespace {
152152
/// This class contains a snapshot of the current conversion rewriter state.
153153
/// This is useful when saving and undoing a set of rewrites.
154154
struct RewriterState {
155-
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156-
unsigned numRewrites, unsigned numIgnoredOperations,
157-
unsigned numErased)
158-
: numCreatedOps(numCreatedOps),
159-
numUnresolvedMaterializations(numUnresolvedMaterializations),
155+
RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
156+
unsigned numIgnoredOperations, unsigned numErased)
157+
: numUnresolvedMaterializations(numUnresolvedMaterializations),
160158
numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
161159
numErased(numErased) {}
162160

163-
/// The current number of created operations.
164-
unsigned numCreatedOps;
165-
166161
/// The current number of unresolved materializations.
167162
unsigned numUnresolvedMaterializations;
168163

@@ -303,7 +298,8 @@ class IRRewrite {
303298
// Operation rewrites
304299
MoveOperation,
305300
ModifyOperation,
306-
ReplaceOperation
301+
ReplaceOperation,
302+
CreateOperation
307303
};
308304

309305
virtual ~IRRewrite() = default;
@@ -376,7 +372,10 @@ class CreateBlockRewrite : public BlockRewrite {
376372
auto &blockOps = block->getOperations();
377373
while (!blockOps.empty())
378374
blockOps.remove(blockOps.begin());
379-
eraseBlock(block);
375+
if (block->getParent())
376+
eraseBlock(block);
377+
else
378+
delete block;
380379
}
381380
};
382381

@@ -606,7 +605,7 @@ class OperationRewrite : public IRRewrite {
606605

607606
static bool classof(const IRRewrite *rewrite) {
608607
return rewrite->getKind() >= Kind::MoveOperation &&
609-
rewrite->getKind() <= Kind::ReplaceOperation;
608+
rewrite->getKind() <= Kind::CreateOperation;
610609
}
611610

612611
protected:
@@ -740,6 +739,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
740739
/// A boolean flag that indicates whether result types have changed or not.
741740
bool changedResults;
742741
};
742+
743+
class CreateOperationRewrite : public OperationRewrite {
744+
public:
745+
CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
746+
Operation *op)
747+
: OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
748+
749+
static bool classof(const IRRewrite *rewrite) {
750+
return rewrite->getKind() == Kind::CreateOperation;
751+
}
752+
753+
void rollback() override;
754+
};
743755
} // namespace
744756

745757
/// Return "true" if there is an operation rewrite that matches the specified
@@ -957,9 +969,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
957969
// replacing a value with one of a different type.
958970
ConversionValueMapping mapping;
959971

960-
/// Ordered vector of all of the newly created operations during conversion.
961-
SmallVector<Operation *> createdOps;
962-
963972
/// Ordered vector of all unresolved type conversion materializations during
964973
/// conversion.
965974
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1144,6 +1153,15 @@ void ReplaceOperationRewrite::rollback() {
11441153

11451154
void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
11461155

1156+
void CreateOperationRewrite::rollback() {
1157+
for (Region &region : op->getRegions()) {
1158+
while (!region.getBlocks().empty())
1159+
region.getBlocks().remove(region.getBlocks().begin());
1160+
}
1161+
op->dropAllUses();
1162+
eraseOp(op);
1163+
}
1164+
11471165
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
11481166
for (Region &region : op->getRegions()) {
11491167
for (Block &block : region.getBlocks()) {
@@ -1161,8 +1179,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
11611179
// Remove any newly created ops.
11621180
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
11631181
detachNestedAndErase(materialization.getOp());
1164-
for (auto *op : llvm::reverse(createdOps))
1165-
detachNestedAndErase(op);
11661182
}
11671183

11681184
void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1182,9 +1198,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11821198
// State Management
11831199

11841200
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1185-
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
1186-
rewrites.size(), ignoredOps.size(),
1187-
eraseRewriter.erased.size());
1201+
return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
1202+
ignoredOps.size(), eraseRewriter.erased.size());
11881203
}
11891204

11901205
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1205,12 +1220,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
12051220
detachNestedAndErase(op);
12061221
}
12071222

1208-
// Pop all of the newly created operations.
1209-
while (createdOps.size() != state.numCreatedOps) {
1210-
detachNestedAndErase(createdOps.back());
1211-
createdOps.pop_back();
1212-
}
1213-
12141223
// Pop all of the recorded ignored operations that are no longer valid.
12151224
while (ignoredOps.size() != state.numIgnoredOperations)
12161225
ignoredOps.pop_back();
@@ -1478,7 +1487,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14781487
});
14791488
if (!previous.isSet()) {
14801489
// This is a newly created op.
1481-
createdOps.push_back(op);
1490+
appendRewrite<CreateOperationRewrite>(op);
14821491
return;
14831492
}
14841493
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1979,13 +1988,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
19791988
rewriter.replaceOp(op, replacementValues);
19801989

19811990
// Recursively legalize any new constant operations.
1982-
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1991+
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
19831992
i != e; ++i) {
1984-
Operation *cstOp = rewriterImpl.createdOps[i];
1985-
if (failed(legalize(cstOp, rewriter))) {
1993+
auto *createOp =
1994+
dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
1995+
if (!createOp)
1996+
continue;
1997+
if (failed(legalize(createOp->getOperation(), rewriter))) {
19861998
LLVM_DEBUG(logFailure(rewriterImpl.logger,
19871999
"failed to legalize generated constant '{0}'",
1988-
cstOp->getName()));
2000+
createOp->getOperation()->getName()));
19892001
rewriterImpl.resetState(curState);
19902002
return failure();
19912003
}
@@ -2132,9 +2144,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21322144
// blocks in regions created by this pattern will already be legalized later
21332145
// on. If we haven't built the set yet, build it now.
21342146
if (operationsToIgnore.empty()) {
2135-
auto createdOps = ArrayRef<Operation *>(impl.createdOps)
2136-
.drop_front(state.numCreatedOps);
2137-
operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2147+
for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2148+
++i) {
2149+
auto *createOp =
2150+
dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2151+
if (!createOp)
2152+
continue;
2153+
operationsToIgnore.insert(createOp->getOperation());
2154+
}
21382155
}
21392156

21402157
// If this operation should be considered for re-legalization, try it.
@@ -2152,8 +2169,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21522169
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
21532170
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
21542171
RewriterState &state, RewriterState &newState) {
2155-
for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2156-
Operation *op = impl.createdOps[i];
2172+
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2173+
auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2174+
if (!createOp)
2175+
continue;
2176+
Operation *op = createOp->getOperation();
21572177
if (failed(legalize(op, rewriter))) {
21582178
LLVM_DEBUG(logFailure(impl.logger,
21592179
"failed to legalize generated operation '{0}'({1})",
@@ -2583,10 +2603,16 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25832603
});
25842604
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
25852605
};
2586-
for (auto &r : rewriterImpl.rewrites)
2587-
if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
2588-
if (failed(rewrite->materializeLiveConversions(findLiveUser)))
2606+
// Note: `rewrites` may be reallocated as the loop is running.
2607+
for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2608+
++i) {
2609+
auto &rewrite = rewriterImpl.rewrites[i];
2610+
if (auto *blockTypeConversionRewrite =
2611+
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2612+
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2613+
findLiveUser)))
25892614
return failure();
2615+
}
25902616
return success();
25912617
}
25922618

0 commit comments

Comments
 (0)