Skip to content

Commit 3873a3e

Browse files
[mlir][Transforms][NFC] Turn op creation into IRRewrite
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be commited (upon success) or rolled back (upon failure). Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an `IRRewrite` that can be committed and rolled back just like any other rewrite. This commit simplifies the internal state of the dialect conversion.
1 parent 613a616 commit 3873a3e

File tree

1 file changed

+66
-38
lines changed

1 file changed

+66
-38
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,12 @@ namespace {
152152
/// This class contains a snapshot of the current conversion rewriter state.
153153
/// This is useful when saving and undoing a set of rewrites.
154154
struct RewriterState {
155-
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156-
unsigned numRewrites, unsigned numIgnoredOperations,
157-
unsigned numErased)
158-
: numCreatedOps(numCreatedOps),
159-
numUnresolvedMaterializations(numUnresolvedMaterializations),
155+
RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
156+
unsigned numIgnoredOperations, unsigned numErased)
157+
: numUnresolvedMaterializations(numUnresolvedMaterializations),
160158
numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
161159
numErased(numErased) {}
162160

163-
/// The current number of created operations.
164-
unsigned numCreatedOps;
165-
166161
/// The current number of unresolved materializations.
167162
unsigned numUnresolvedMaterializations;
168163

@@ -299,7 +294,8 @@ class IRRewrite {
299294
ReplaceBlockArg,
300295
MoveOperation,
301296
ModifyOperation,
302-
ReplaceOperation
297+
ReplaceOperation,
298+
CreateOperation
303299
};
304300

305301
virtual ~IRRewrite() = default;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
372368
auto &blockOps = block->getOperations();
373369
while (!blockOps.empty())
374370
blockOps.remove(blockOps.begin());
375-
eraseBlock(block);
371+
if (block->getParent()) {
372+
eraseBlock(block);
373+
} else {
374+
delete block;
375+
}
376376
}
377377
};
378378

@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
602602

603603
static bool classof(const IRRewrite *rewrite) {
604604
return rewrite->getKind() >= Kind::MoveOperation &&
605-
rewrite->getKind() <= Kind::ReplaceOperation;
605+
rewrite->getKind() <= Kind::CreateOperation;
606606
}
607607

608608
protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
708708
/// 1->N conversion of some kind.
709709
bool changedResults;
710710
};
711+
712+
class CreateOperationRewrite : public OperationRewrite {
713+
public:
714+
CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
715+
Operation *op)
716+
: OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
717+
718+
static bool classof(const IRRewrite *rewrite) {
719+
return rewrite->getKind() == Kind::CreateOperation;
720+
}
721+
722+
void rollback() override;
723+
};
711724
} // namespace
712725

713726
/// Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
925938
// replacing a value with one of a different type.
926939
ConversionValueMapping mapping;
927940

928-
/// Ordered vector of all of the newly created operations during conversion.
929-
SmallVector<Operation *> createdOps;
930-
931941
/// Ordered vector of all unresolved type conversion materializations during
932942
/// conversion.
933943
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
11101120

11111121
void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
11121122

1123+
void CreateOperationRewrite::rollback() {
1124+
for (Region &region : op->getRegions()) {
1125+
while (!region.getBlocks().empty())
1126+
region.getBlocks().remove(region.getBlocks().begin());
1127+
}
1128+
op->dropAllUses();
1129+
eraseOp(op);
1130+
}
1131+
11131132
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
1133+
// if (erasedIR.erasedOps.contains(op)) return;
1134+
11141135
for (Region &region : op->getRegions()) {
11151136
for (Block &block : region.getBlocks()) {
11161137
while (!block.getOperations().empty())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
11271148
// Remove any newly created ops.
11281149
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
11291150
detachNestedAndErase(materialization.getOp());
1130-
for (auto *op : llvm::reverse(createdOps))
1131-
detachNestedAndErase(op);
11321151
}
11331152

11341153
void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11481167
// State Management
11491168

11501169
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1151-
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
1152-
rewrites.size(), ignoredOps.size(),
1153-
eraseRewriter.erased.size());
1170+
return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
1171+
ignoredOps.size(), eraseRewriter.erased.size());
11541172
}
11551173

11561174
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
11711189
detachNestedAndErase(op);
11721190
}
11731191

1174-
// Pop all of the newly created operations.
1175-
while (createdOps.size() != state.numCreatedOps) {
1176-
detachNestedAndErase(createdOps.back());
1177-
createdOps.pop_back();
1178-
}
1179-
11801192
// Pop all of the recorded ignored operations that are no longer valid.
11811193
while (ignoredOps.size() != state.numIgnoredOperations)
11821194
ignoredOps.pop_back();
@@ -1460,7 +1472,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14601472
});
14611473
if (!previous.isSet()) {
14621474
// This is a newly created op.
1463-
createdOps.push_back(op);
1475+
appendRewrite<CreateOperationRewrite>(op);
14641476
return;
14651477
}
14661478
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1961,13 +1973,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
19611973
rewriter.replaceOp(op, replacementValues);
19621974

19631975
// Recursively legalize any new constant operations.
1964-
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1976+
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
19651977
i != e; ++i) {
1966-
Operation *cstOp = rewriterImpl.createdOps[i];
1967-
if (failed(legalize(cstOp, rewriter))) {
1978+
auto *createOp =
1979+
dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
1980+
if (!createOp)
1981+
continue;
1982+
if (failed(legalize(createOp->getOperation(), rewriter))) {
19681983
LLVM_DEBUG(logFailure(rewriterImpl.logger,
19691984
"failed to legalize generated constant '{0}'",
1970-
cstOp->getName()));
1985+
createOp->getOperation()->getName()));
19711986
rewriterImpl.resetState(curState);
19721987
return failure();
19731988
}
@@ -2112,9 +2127,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21122127
// blocks in regions created by this pattern will already be legalized later
21132128
// on. If we haven't built the set yet, build it now.
21142129
if (operationsToIgnore.empty()) {
2115-
auto createdOps = ArrayRef<Operation *>(impl.createdOps)
2116-
.drop_front(state.numCreatedOps);
2117-
operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2130+
for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2131+
++i) {
2132+
auto *createOp =
2133+
dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2134+
if (!createOp)
2135+
continue;
2136+
operationsToIgnore.insert(createOp->getOperation());
2137+
}
21182138
}
21192139

21202140
// If this operation should be considered for re-legalization, try it.
@@ -2132,8 +2152,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21322152
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
21332153
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
21342154
RewriterState &state, RewriterState &newState) {
2135-
for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2136-
Operation *op = impl.createdOps[i];
2155+
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2156+
auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2157+
if (!createOp)
2158+
continue;
2159+
Operation *op = createOp->getOperation();
21372160
if (failed(legalize(op, rewriter))) {
21382161
LLVM_DEBUG(logFailure(impl.logger,
21392162
"failed to legalize generated operation '{0}'({1})",
@@ -2563,10 +2586,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25632586
});
25642587
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
25652588
};
2566-
for (auto &r : rewriterImpl.rewrites)
2567-
if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
2568-
if (failed(rewrite->materializeLiveConversions(findLiveUser)))
2589+
// Note: `rewrites` may be reallocated as the loop is running.
2590+
for (int64_t i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2591+
auto &rewrite = rewriterImpl.rewrites[i];
2592+
if (auto *blockTypeConversionRewrite =
2593+
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2594+
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2595+
findLiveUser)))
25692596
return failure();
2597+
}
25702598
return success();
25712599
}
25722600

0 commit comments

Comments
 (0)