Skip to content

Commit da7b888

Browse files
no rollback flag
1 parent 63e2888 commit da7b888

File tree

2 files changed

+63
-14
lines changed

2 files changed

+63
-14
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,26 @@ struct ConversionConfig {
12191219
/// materializations and instead inserts "builtin.unrealized_conversion_cast"
12201220
/// ops to ensure that the resulting IR is valid.
12211221
bool buildMaterializations = true;
1222+
1223+
/// If set to "true", pattern rollback is allowed. The conversion driver
1224+
/// rolls back IR modifications in the following situations.
1225+
///
1226+
/// 1. Pattern implementation returns "failure" after modifying IR.
1227+
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
1228+
/// and cannot be legalized by subsequent foldings / pattern applications.
1229+
///
1230+
/// If set to "false", the conversion driver will produce an LLVM fatal error
1231+
/// instead of rolling back IR modifications. Moreover, in case of a failed
1232+
/// conversion, the original IR is not restored. The resulting IR may be a
1233+
/// mix of original and rewritten IR. (Same as a failed greedy pattern
1234+
/// rewrite.)
1235+
///
1236+
/// Note: This flag was added in preparation of the One-Shot Dialect
1237+
/// Conversion refactoring, which will remove the ability to roll back IR
1238+
/// modifications from the conversion driver. Use this flag to ensure that
1239+
/// your patterns do not trigger any IR rollbacks. For details, see
1240+
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
1241+
bool allowPatternRollback = true;
12221242
};
12231243

12241244
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
861861
/// conversion process succeeds.
862862
void applyRewrites();
863863

864-
/// Reset the state of the rewriter to a previously saved point.
865-
void resetState(RewriterState state);
864+
/// Reset the state of the rewriter to a previously saved point. Optionally,
865+
/// the name of the pattern that triggered the rollback can specified for
866+
/// debugging purposes.
867+
void resetState(RewriterState state, StringRef patternName = "");
866868

867869
/// Append a rewrite. Rewrites are committed upon success and rolled back upon
868870
/// failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
873875
}
874876

875877
/// Undo the rewrites (motions, splits) one by one in reverse order until
876-
/// "numRewritesToKeep" rewrites remains.
877-
void undoRewrites(unsigned numRewritesToKeep = 0);
878+
/// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
879+
/// that triggered the rollback can specified for debugging purposes.
880+
void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
878881

879882
/// Remap the given values to those with potentially different types. Returns
880883
/// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
12041207
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
12051208
}
12061209

1207-
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1210+
void ConversionPatternRewriterImpl::resetState(RewriterState state,
1211+
StringRef patternName) {
12081212
// Undo any rewrites.
1209-
undoRewrites(state.numRewrites);
1213+
undoRewrites(state.numRewrites, patternName);
12101214

12111215
// Pop all of the recorded ignored operations that are no longer valid.
12121216
while (ignoredOps.size() != state.numIgnoredOperations)
@@ -1216,10 +1220,19 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
12161220
replacedOps.pop_back();
12171221
}
12181222

1219-
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1223+
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1224+
StringRef patternName) {
12201225
for (auto &rewrite :
1221-
llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1226+
llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
1227+
if (!config.allowPatternRollback &&
1228+
!isa<UnresolvedMaterializationRewrite>(rewrite)) {
1229+
// Unresolved materializations can always be rolled back (erased).
1230+
std::string errorMessage = "pattern '" + std::string(patternName) +
1231+
"' rollback of IR modifications requested";
1232+
llvm_unreachable(errorMessage.c_str());
1233+
}
12221234
rewrite->rollback();
1235+
}
12231236
rewrites.resize(numRewritesToKeep);
12241237
}
12251238

@@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21582171
});
21592172
if (config.listener)
21602173
config.listener->notifyPatternEnd(pattern, failure());
2161-
rewriterImpl.resetState(curState);
2174+
rewriterImpl.resetState(curState, pattern.getDebugName());
21622175
appliedPatterns.erase(&pattern);
21632176
};
21642177

@@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21682181
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
21692182
auto result = legalizePatternResult(op, pattern, rewriter, curState);
21702183
appliedPatterns.erase(&pattern);
2171-
if (failed(result))
2172-
rewriterImpl.resetState(curState);
2184+
if (failed(result)) {
2185+
if (!rewriterImpl.config.allowPatternRollback)
2186+
op->emitError("pattern '")
2187+
<< pattern.getDebugName()
2188+
<< "' produced IR that could not be legalized";
2189+
rewriterImpl.resetState(curState, pattern.getDebugName());
2190+
}
21732191
if (config.listener)
21742192
config.listener->notifyPatternEnd(pattern, result);
21752193
return result;
@@ -2674,9 +2692,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
26742692
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
26752693
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
26762694

2677-
for (auto *op : toConvert)
2678-
if (failed(convert(rewriter, op)))
2679-
return rewriterImpl.undoRewrites(), failure();
2695+
for (auto *op : toConvert) {
2696+
if (failed(convert(rewriter, op))) {
2697+
// Dialect conversion failed.
2698+
if (rewriterImpl.config.allowPatternRollback) {
2699+
// Rollback is allowed: restore the original IR.
2700+
rewriterImpl.undoRewrites();
2701+
} else {
2702+
// Rollback is not allowed: apply all modifications that have been
2703+
// performed so far.
2704+
rewriterImpl.applyRewrites();
2705+
}
2706+
return failure();
2707+
}
2708+
}
26802709

26812710
// After a successful conversion, apply rewrites.
26822711
rewriterImpl.applyRewrites();

0 commit comments

Comments
 (0)