-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][Transforms][NFC] Turn op creation into IRRewrite
#81759
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure). Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an Full diff: https://github.com/llvm/llvm-project/pull/81759.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a07c8a56822de5..2bb56d5df1ebd3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,17 +152,12 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
- unsigned numRewrites, unsigned numIgnoredOperations,
- unsigned numErased)
- : numCreatedOps(numCreatedOps),
- numUnresolvedMaterializations(numUnresolvedMaterializations),
+ RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
+ unsigned numIgnoredOperations, unsigned numErased)
+ : numUnresolvedMaterializations(numUnresolvedMaterializations),
numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
numErased(numErased) {}
- /// The current number of created operations.
- unsigned numCreatedOps;
-
/// The current number of unresolved materializations.
unsigned numUnresolvedMaterializations;
@@ -299,7 +294,8 @@ class IRRewrite {
ReplaceBlockArg,
MoveOperation,
ModifyOperation,
- ReplaceOperation
+ ReplaceOperation,
+ CreateOperation
};
virtual ~IRRewrite() = default;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
- eraseBlock(block);
+ if (block->getParent()) {
+ eraseBlock(block);
+ } else {
+ delete block;
+ }
}
};
@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::ReplaceOperation;
+ rewrite->getKind() <= Kind::CreateOperation;
}
protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
/// 1->N conversion of some kind.
bool changedResults;
};
+
+class CreateOperationRewrite : public OperationRewrite {
+public:
+ CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::CreateOperation;
+ }
+
+ void rollback() override;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// replacing a value with one of a different type.
ConversionValueMapping mapping;
- /// Ordered vector of all of the newly created operations during conversion.
- SmallVector<Operation *> createdOps;
-
/// Ordered vector of all unresolved type conversion materializations during
/// conversion.
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void CreateOperationRewrite::rollback() {
+ for (Region ®ion : op->getRegions()) {
+ while (!region.getBlocks().empty())
+ region.getBlocks().remove(region.getBlocks().begin());
+ }
+ op->dropAllUses();
+ eraseOp(op);
+}
+
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
+ // if (erasedIR.erasedOps.contains(op)) return;
+
for (Region ®ion : op->getRegions()) {
for (Block &block : region.getBlocks()) {
while (!block.getOperations().empty())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
// Remove any newly created ops.
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
detachNestedAndErase(materialization.getOp());
- for (auto *op : llvm::reverse(createdOps))
- detachNestedAndErase(op);
}
void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
- rewrites.size(), ignoredOps.size(),
- eraseRewriter.erased.size());
+ return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
+ ignoredOps.size(), eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
detachNestedAndErase(op);
}
- // Pop all of the newly created operations.
- while (createdOps.size() != state.numCreatedOps) {
- detachNestedAndErase(createdOps.back());
- createdOps.pop_back();
- }
-
// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
@@ -1460,7 +1472,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
});
if (!previous.isSet()) {
// This is a newly created op.
- createdOps.push_back(op);
+ appendRewrite<CreateOperationRewrite>(op);
return;
}
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1961,13 +1973,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
i != e; ++i) {
- Operation *cstOp = rewriterImpl.createdOps[i];
- if (failed(legalize(cstOp, rewriter))) {
+ auto *createOp =
+ dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ if (failed(legalize(createOp->getOperation(), rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- cstOp->getName()));
+ createOp->getOperation()->getName()));
rewriterImpl.resetState(curState);
return failure();
}
@@ -2112,9 +2127,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// blocks in regions created by this pattern will already be legalized later
// on. If we haven't built the set yet, build it now.
if (operationsToIgnore.empty()) {
- auto createdOps = ArrayRef<Operation *>(impl.createdOps)
- .drop_front(state.numCreatedOps);
- operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+ for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
+ ++i) {
+ auto *createOp =
+ dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ operationsToIgnore.insert(createOp->getOperation());
+ }
}
// If this operation should be considered for re-legalization, try it.
@@ -2132,8 +2152,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
- for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
- Operation *op = impl.createdOps[i];
+ for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+ auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ Operation *op = createOp->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(impl.logger,
"failed to legalize generated operation '{0}'({1})",
@@ -2563,10 +2586,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
- for (auto &r : rewriterImpl.rewrites)
- if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
- if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+ // Note: `rewrites` may be reallocated as the loop is running.
+ for (int64_t i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+ auto &rewrite = rewriterImpl.rewrites[i];
+ if (auto *blockTypeConversionRewrite =
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+ if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+ findLiveUser)))
return failure();
+ }
return success();
}
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure). Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an Full diff: https://github.com/llvm/llvm-project/pull/81759.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a07c8a56822de5..2bb56d5df1ebd3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,17 +152,12 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
- unsigned numRewrites, unsigned numIgnoredOperations,
- unsigned numErased)
- : numCreatedOps(numCreatedOps),
- numUnresolvedMaterializations(numUnresolvedMaterializations),
+ RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
+ unsigned numIgnoredOperations, unsigned numErased)
+ : numUnresolvedMaterializations(numUnresolvedMaterializations),
numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
numErased(numErased) {}
- /// The current number of created operations.
- unsigned numCreatedOps;
-
/// The current number of unresolved materializations.
unsigned numUnresolvedMaterializations;
@@ -299,7 +294,8 @@ class IRRewrite {
ReplaceBlockArg,
MoveOperation,
ModifyOperation,
- ReplaceOperation
+ ReplaceOperation,
+ CreateOperation
};
virtual ~IRRewrite() = default;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
- eraseBlock(block);
+ if (block->getParent()) {
+ eraseBlock(block);
+ } else {
+ delete block;
+ }
}
};
@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::ReplaceOperation;
+ rewrite->getKind() <= Kind::CreateOperation;
}
protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
/// 1->N conversion of some kind.
bool changedResults;
};
+
+class CreateOperationRewrite : public OperationRewrite {
+public:
+ CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::CreateOperation;
+ }
+
+ void rollback() override;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// replacing a value with one of a different type.
ConversionValueMapping mapping;
- /// Ordered vector of all of the newly created operations during conversion.
- SmallVector<Operation *> createdOps;
-
/// Ordered vector of all unresolved type conversion materializations during
/// conversion.
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void CreateOperationRewrite::rollback() {
+ for (Region ®ion : op->getRegions()) {
+ while (!region.getBlocks().empty())
+ region.getBlocks().remove(region.getBlocks().begin());
+ }
+ op->dropAllUses();
+ eraseOp(op);
+}
+
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
+ // if (erasedIR.erasedOps.contains(op)) return;
+
for (Region ®ion : op->getRegions()) {
for (Block &block : region.getBlocks()) {
while (!block.getOperations().empty())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
// Remove any newly created ops.
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
detachNestedAndErase(materialization.getOp());
- for (auto *op : llvm::reverse(createdOps))
- detachNestedAndErase(op);
}
void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
- rewrites.size(), ignoredOps.size(),
- eraseRewriter.erased.size());
+ return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
+ ignoredOps.size(), eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
detachNestedAndErase(op);
}
- // Pop all of the newly created operations.
- while (createdOps.size() != state.numCreatedOps) {
- detachNestedAndErase(createdOps.back());
- createdOps.pop_back();
- }
-
// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
@@ -1460,7 +1472,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
});
if (!previous.isSet()) {
// This is a newly created op.
- createdOps.push_back(op);
+ appendRewrite<CreateOperationRewrite>(op);
return;
}
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1961,13 +1973,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
i != e; ++i) {
- Operation *cstOp = rewriterImpl.createdOps[i];
- if (failed(legalize(cstOp, rewriter))) {
+ auto *createOp =
+ dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ if (failed(legalize(createOp->getOperation(), rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- cstOp->getName()));
+ createOp->getOperation()->getName()));
rewriterImpl.resetState(curState);
return failure();
}
@@ -2112,9 +2127,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// blocks in regions created by this pattern will already be legalized later
// on. If we haven't built the set yet, build it now.
if (operationsToIgnore.empty()) {
- auto createdOps = ArrayRef<Operation *>(impl.createdOps)
- .drop_front(state.numCreatedOps);
- operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+ for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
+ ++i) {
+ auto *createOp =
+ dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ operationsToIgnore.insert(createOp->getOperation());
+ }
}
// If this operation should be considered for re-legalization, try it.
@@ -2132,8 +2152,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
- for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
- Operation *op = impl.createdOps[i];
+ for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+ auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ Operation *op = createOp->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(impl.logger,
"failed to legalize generated operation '{0}'({1})",
@@ -2563,10 +2586,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
- for (auto &r : rewriterImpl.rewrites)
- if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
- if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+ // Note: `rewrites` may be reallocated as the loop is running.
+ for (int64_t i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+ auto &rewrite = rewriterImpl.rewrites[i];
+ if (auto *blockTypeConversionRewrite =
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+ if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+ findLiveUser)))
return failure();
+ }
return success();
}
|
8405efc
to
613a616
Compare
572d77c
to
3873a3e
Compare
613a616
to
b8d4cbd
Compare
3873a3e
to
6701034
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@@ -299,7 +294,8 @@ class IRRewrite { | |||
ReplaceBlockArg, | |||
MoveOperation, | |||
ModifyOperation, | |||
ReplaceOperation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mmm, come to think of it: would it make sense to have "marker" types here so that you wouldn't need to change below if you add types & also where one adds entries here is self-documented due to markers?
886f558
to
6f7d3e7
Compare
6701034
to
d15c439
Compare
d15c439
to
e4e7d7c
Compare
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be commited (upon success) or rolled back (upon failure). Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an `IRRewrite` that can be committed and rolled back just like any other rewrite. This commit simplifies the internal state of the dialect conversion.
e4e7d7c
to
8859a66
Compare
This commit fixes memory leaks that were introduced by llvm#81759. The way ops and blocks are erased changed slightly. The leaks were caused by an incorrect implementation of op builders: blocks must be created with the supplied builder object. Otherwise, they are not properly tracked by the dialect conversion and can leak during rollback.
This commit fixes memory leaks that were introduced by #81759. The way ops and blocks are erased changed slightly. The leaks were caused by an incorrect implementation of op builders: blocks must be created with the supplied builder object. Otherwise, they are not properly tracked by the dialect conversion and can leak during rollback.
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
Until now, the dialect conversion kept track of "op creation" in separate internal data structures. This commit turns "op creation" into an
IRRewrite
that can be committed and rolled back just like any other rewrite. This commit simplifies the internal state of the dialect conversion.