Skip to content

[mlir][Transforms] Dialect conversion: Add flag to disable rollback #136490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,26 @@ struct ConversionConfig {
/// materializations and instead inserts "builtin.unrealized_conversion_cast"
/// ops to ensure that the resulting IR is valid.
bool buildMaterializations = true;

/// If set to "true", pattern rollback is allowed. The conversion driver
/// rolls back IR modifications in the following situations.
///
/// 1. Pattern implementation returns "failure" after modifying IR.
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
/// and cannot be legalized by subsequent foldings / pattern applications.
///
/// If set to "false", the conversion driver will produce an LLVM fatal error
/// instead of rolling back IR modifications. Moreover, in case of a failed
/// conversion, the original IR is not restored. The resulting IR may be a
/// mix of original and rewritten IR. (Same as a failed greedy pattern
/// rewrite.)
///
/// Note: This flag was added in preparation of the One-Shot Dialect
/// Conversion refactoring, which will remove the ability to roll back IR
/// modifications from the conversion driver. Use this flag to ensure that
/// your patterns do not trigger any IR rollbacks. For details, see
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
bool allowPatternRollback = true;
};

//===----------------------------------------------------------------------===//
Expand Down
56 changes: 42 additions & 14 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// conversion process succeeds.
void applyRewrites();

/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
/// Reset the state of the rewriter to a previously saved point. Optionally,
/// the name of the pattern that triggered the rollback can specified for
/// debugging purposes.
void resetState(RewriterState state, StringRef patternName = "");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are patternName params intended to stay post?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This entire function will disappear when we delete the rollback mechanism.


/// Append a rewrite. Rewrites are committed upon success and rolled back upon
/// failure.
Expand All @@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
}

/// Undo the rewrites (motions, splits) one by one in reverse order until
/// "numRewritesToKeep" rewrites remains.
void undoRewrites(unsigned numRewritesToKeep = 0);
/// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
/// that triggered the rollback can specified for debugging purposes.
void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");

/// Remap the given values to those with potentially different types. Returns
/// success if the values could be remapped, failure otherwise. `valueDiagTag`
Expand Down Expand Up @@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
}

void ConversionPatternRewriterImpl::resetState(RewriterState state) {
void ConversionPatternRewriterImpl::resetState(RewriterState state,
StringRef patternName) {
// Undo any rewrites.
undoRewrites(state.numRewrites);
undoRewrites(state.numRewrites, patternName);

// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
Expand All @@ -1216,10 +1220,18 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
replacedOps.pop_back();
}

void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
StringRef patternName) {
for (auto &rewrite :
llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
if (!config.allowPatternRollback &&
!isa<UnresolvedMaterializationRewrite>(rewrite)) {
// Unresolved materializations can always be rolled back (erased).
llvm::report_fatal_error("pattern '" + patternName +
"' rollback of IR modifications requested");
}
rewrite->rollback();
}
rewrites.resize(numRewritesToKeep);
}

Expand Down Expand Up @@ -2158,7 +2170,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
});
if (config.listener)
config.listener->notifyPatternEnd(pattern, failure());
rewriterImpl.resetState(curState);
rewriterImpl.resetState(curState, pattern.getDebugName());
appliedPatterns.erase(&pattern);
};

Expand All @@ -2168,8 +2180,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
auto result = legalizePatternResult(op, pattern, rewriter, curState);
appliedPatterns.erase(&pattern);
if (failed(result))
rewriterImpl.resetState(curState);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
op->emitError("pattern '")
<< pattern.getDebugName()
<< "' produced IR that could not be legalized";
rewriterImpl.resetState(curState, pattern.getDebugName());
}
if (config.listener)
config.listener->notifyPatternEnd(pattern, result);
return result;
Expand Down Expand Up @@ -2674,9 +2691,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();

for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
return rewriterImpl.undoRewrites(), failure();
for (auto *op : toConvert) {
if (failed(convert(rewriter, op))) {
// Dialect conversion failed.
if (rewriterImpl.config.allowPatternRollback) {
// Rollback is allowed: restore the original IR.
rewriterImpl.undoRewrites();
} else {
// Rollback is not allowed: apply all modifications that have been
// performed so far.
rewriterImpl.applyRewrites();
}
return failure();
}
}

// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();
Expand Down
Loading