Skip to content

Commit 6701034

Browse files
[mlir][Transforms][NFC] Turn op creation into IRRewrite
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be commited (upon success) or rolled back (upon failure). Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an `IRRewrite` that can be committed and rolled back just like any other rewrite. This commit simplifies the internal state of the dialect conversion.
1 parent b8d4cbd commit 6701034

File tree

1 file changed

+66
-38
lines changed

1 file changed

+66
-38
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,12 @@ namespace {
152152
/// This class contains a snapshot of the current conversion rewriter state.
153153
/// This is useful when saving and undoing a set of rewrites.
154154
struct RewriterState {
155-
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156-
unsigned numRewrites, unsigned numIgnoredOperations,
157-
unsigned numErased)
158-
: numCreatedOps(numCreatedOps),
159-
numUnresolvedMaterializations(numUnresolvedMaterializations),
155+
RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
156+
unsigned numIgnoredOperations, unsigned numErased)
157+
: numUnresolvedMaterializations(numUnresolvedMaterializations),
160158
numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
161159
numErased(numErased) {}
162160

163-
/// The current number of created operations.
164-
unsigned numCreatedOps;
165-
166161
/// The current number of unresolved materializations.
167162
unsigned numUnresolvedMaterializations;
168163

@@ -299,7 +294,8 @@ class IRRewrite {
299294
ReplaceBlockArg,
300295
MoveOperation,
301296
ModifyOperation,
302-
ReplaceOperation
297+
ReplaceOperation,
298+
CreateOperation
303299
};
304300

305301
virtual ~IRRewrite() = default;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
372368
auto &blockOps = block->getOperations();
373369
while (!blockOps.empty())
374370
blockOps.remove(blockOps.begin());
375-
eraseBlock(block);
371+
if (block->getParent()) {
372+
eraseBlock(block);
373+
} else {
374+
delete block;
375+
}
376376
}
377377
};
378378

@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
602602

603603
static bool classof(const IRRewrite *rewrite) {
604604
return rewrite->getKind() >= Kind::MoveOperation &&
605-
rewrite->getKind() <= Kind::ReplaceOperation;
605+
rewrite->getKind() <= Kind::CreateOperation;
606606
}
607607

608608
protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
708708
/// 1->N conversion of some kind.
709709
bool changedResults;
710710
};
711+
712+
class CreateOperationRewrite : public OperationRewrite {
713+
public:
714+
CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
715+
Operation *op)
716+
: OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
717+
718+
static bool classof(const IRRewrite *rewrite) {
719+
return rewrite->getKind() == Kind::CreateOperation;
720+
}
721+
722+
void rollback() override;
723+
};
711724
} // namespace
712725

713726
/// Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
925938
// replacing a value with one of a different type.
926939
ConversionValueMapping mapping;
927940

928-
/// Ordered vector of all of the newly created operations during conversion.
929-
SmallVector<Operation *> createdOps;
930-
931941
/// Ordered vector of all unresolved type conversion materializations during
932942
/// conversion.
933943
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
11101120

11111121
void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
11121122

1123+
void CreateOperationRewrite::rollback() {
1124+
for (Region &region : op->getRegions()) {
1125+
while (!region.getBlocks().empty())
1126+
region.getBlocks().remove(region.getBlocks().begin());
1127+
}
1128+
op->dropAllUses();
1129+
eraseOp(op);
1130+
}
1131+
11131132
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
1133+
// if (erasedIR.erasedOps.contains(op)) return;
1134+
11141135
for (Region &region : op->getRegions()) {
11151136
for (Block &block : region.getBlocks()) {
11161137
while (!block.getOperations().empty())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
11271148
// Remove any newly created ops.
11281149
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
11291150
detachNestedAndErase(materialization.getOp());
1130-
for (auto *op : llvm::reverse(createdOps))
1131-
detachNestedAndErase(op);
11321151
}
11331152

11341153
void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11481167
// State Management
11491168

11501169
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1151-
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
1152-
rewrites.size(), ignoredOps.size(),
1153-
eraseRewriter.erased.size());
1170+
return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
1171+
ignoredOps.size(), eraseRewriter.erased.size());
11541172
}
11551173

11561174
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
11711189
detachNestedAndErase(op);
11721190
}
11731191

1174-
// Pop all of the newly created operations.
1175-
while (createdOps.size() != state.numCreatedOps) {
1176-
detachNestedAndErase(createdOps.back());
1177-
createdOps.pop_back();
1178-
}
1179-
11801192
// Pop all of the recorded ignored operations that are no longer valid.
11811193
while (ignoredOps.size() != state.numIgnoredOperations)
11821194
ignoredOps.pop_back();
@@ -1444,7 +1456,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14441456
});
14451457
if (!previous.isSet()) {
14461458
// This is a newly created op.
1447-
createdOps.push_back(op);
1459+
appendRewrite<CreateOperationRewrite>(op);
14481460
return;
14491461
}
14501462
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1945,13 +1957,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
19451957
rewriter.replaceOp(op, replacementValues);
19461958

19471959
// Recursively legalize any new constant operations.
1948-
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1960+
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
19491961
i != e; ++i) {
1950-
Operation *cstOp = rewriterImpl.createdOps[i];
1951-
if (failed(legalize(cstOp, rewriter))) {
1962+
auto *createOp =
1963+
dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
1964+
if (!createOp)
1965+
continue;
1966+
if (failed(legalize(createOp->getOperation(), rewriter))) {
19521967
LLVM_DEBUG(logFailure(rewriterImpl.logger,
19531968
"failed to legalize generated constant '{0}'",
1954-
cstOp->getName()));
1969+
createOp->getOperation()->getName()));
19551970
rewriterImpl.resetState(curState);
19561971
return failure();
19571972
}
@@ -2098,9 +2113,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
20982113
// blocks in regions created by this pattern will already be legalized later
20992114
// on. If we haven't built the set yet, build it now.
21002115
if (operationsToIgnore.empty()) {
2101-
auto createdOps = ArrayRef<Operation *>(impl.createdOps)
2102-
.drop_front(state.numCreatedOps);
2103-
operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2116+
for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2117+
++i) {
2118+
auto *createOp =
2119+
dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2120+
if (!createOp)
2121+
continue;
2122+
operationsToIgnore.insert(createOp->getOperation());
2123+
}
21042124
}
21052125

21062126
// If this operation should be considered for re-legalization, try it.
@@ -2118,8 +2138,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21182138
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
21192139
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
21202140
RewriterState &state, RewriterState &newState) {
2121-
for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2122-
Operation *op = impl.createdOps[i];
2141+
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2142+
auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2143+
if (!createOp)
2144+
continue;
2145+
Operation *op = createOp->getOperation();
21232146
if (failed(legalize(op, rewriter))) {
21242147
LLVM_DEBUG(logFailure(impl.logger,
21252148
"failed to legalize generated operation '{0}'({1})",
@@ -2549,10 +2572,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25492572
});
25502573
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
25512574
};
2552-
for (auto &r : rewriterImpl.rewrites)
2553-
if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
2554-
if (failed(rewrite->materializeLiveConversions(findLiveUser)))
2575+
// Note: `rewrites` may be reallocated as the loop is running.
2576+
for (int64_t i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2577+
auto &rewrite = rewriterImpl.rewrites[i];
2578+
if (auto *blockTypeConversionRewrite =
2579+
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2580+
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2581+
findLiveUser)))
25552582
return failure();
2583+
}
25562584
return success();
25572585
}
25582586

0 commit comments

Comments
 (0)