Skip to content

Commit beb6d3f

Browse files
matthias-springerIanWood1
authored andcommitted
[mlir][Transforms] Dialect conversion: Add flag to disable rollback (llvm#136490)
This commit adds a new flag to `ConversionConfig` to disallow the rollback of IR modification. This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will remove the ability to roll back IR modifications from the conversion driver. RFC: https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083/46 By default, this flag is set to "true". I.e., the rollback of IR modifications is allowed. When set to "false", the conversion driver will report a fatal LLVM error when an IR rollback is requested. The name of the rolled back pattern is included in the error message. Moreover, the original IR is no longer restored after a failed conversion. Example: ``` within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: error: pattern '(anonymous namespace)::CmpFOpNanKernelPattern' produced IR that could not be legalized %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32 ^ within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: note: see current operation: %1 = "arith.cmpf"(%arg0, %arg1) <{fastmath = #arith.fastmath<fast>, predicate = 7 : i64}> : (f32, f32) -> i1 pattern '(anonymous namespace)::CmpFOpNanKernelPattern' rollback of IR modifications requested UNREACHABLE executed at llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1231! ``` The majority of patterns in MLIR have already been updated such that they do not trigger any rollbacks, but a few SPIRV patterns remain. More information in the RFC.
1 parent 200a0dc commit beb6d3f

File tree

2 files changed

+62
-14
lines changed

2 files changed

+62
-14
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,26 @@ struct ConversionConfig {
12391239
/// materializations and instead inserts "builtin.unrealized_conversion_cast"
12401240
/// ops to ensure that the resulting IR is valid.
12411241
bool buildMaterializations = true;
1242+
1243+
/// If set to "true", pattern rollback is allowed. The conversion driver
1244+
/// rolls back IR modifications in the following situations.
1245+
///
1246+
/// 1. Pattern implementation returns "failure" after modifying IR.
1247+
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
1248+
/// and cannot be legalized by subsequent foldings / pattern applications.
1249+
///
1250+
/// If set to "false", the conversion driver will produce an LLVM fatal error
1251+
/// instead of rolling back IR modifications. Moreover, in case of a failed
1252+
/// conversion, the original IR is not restored. The resulting IR may be a
1253+
/// mix of original and rewritten IR. (Same as a failed greedy pattern
1254+
/// rewrite.)
1255+
///
1256+
/// Note: This flag was added in preparation of the One-Shot Dialect
1257+
/// Conversion refactoring, which will remove the ability to roll back IR
1258+
/// modifications from the conversion driver. Use this flag to ensure that
1259+
/// your patterns do not trigger any IR rollbacks. For details, see
1260+
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
1261+
bool allowPatternRollback = true;
12421262
};
12431263

12441264
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
861861
/// conversion process succeeds.
862862
void applyRewrites();
863863

864-
/// Reset the state of the rewriter to a previously saved point.
865-
void resetState(RewriterState state);
864+
/// Reset the state of the rewriter to a previously saved point. Optionally,
865+
/// the name of the pattern that triggered the rollback can specified for
866+
/// debugging purposes.
867+
void resetState(RewriterState state, StringRef patternName = "");
866868

867869
/// Append a rewrite. Rewrites are committed upon success and rolled back upon
868870
/// failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
873875
}
874876

875877
/// Undo the rewrites (motions, splits) one by one in reverse order until
876-
/// "numRewritesToKeep" rewrites remains.
877-
void undoRewrites(unsigned numRewritesToKeep = 0);
878+
/// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
879+
/// that triggered the rollback can specified for debugging purposes.
880+
void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
878881

879882
/// Remap the given values to those with potentially different types. Returns
880883
/// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
12041207
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
12051208
}
12061209

1207-
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1210+
void ConversionPatternRewriterImpl::resetState(RewriterState state,
1211+
StringRef patternName) {
12081212
// Undo any rewrites.
1209-
undoRewrites(state.numRewrites);
1213+
undoRewrites(state.numRewrites, patternName);
12101214

12111215
// Pop all of the recorded ignored operations that are no longer valid.
12121216
while (ignoredOps.size() != state.numIgnoredOperations)
@@ -1216,10 +1220,18 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
12161220
replacedOps.pop_back();
12171221
}
12181222

1219-
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1223+
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1224+
StringRef patternName) {
12201225
for (auto &rewrite :
1221-
llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1226+
llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
1227+
if (!config.allowPatternRollback &&
1228+
!isa<UnresolvedMaterializationRewrite>(rewrite)) {
1229+
// Unresolved materializations can always be rolled back (erased).
1230+
llvm::report_fatal_error("pattern '" + patternName +
1231+
"' rollback of IR modifications requested");
1232+
}
12221233
rewrite->rollback();
1234+
}
12231235
rewrites.resize(numRewritesToKeep);
12241236
}
12251237

@@ -2158,7 +2170,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21582170
});
21592171
if (config.listener)
21602172
config.listener->notifyPatternEnd(pattern, failure());
2161-
rewriterImpl.resetState(curState);
2173+
rewriterImpl.resetState(curState, pattern.getDebugName());
21622174
appliedPatterns.erase(&pattern);
21632175
};
21642176

@@ -2168,8 +2180,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21682180
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
21692181
auto result = legalizePatternResult(op, pattern, rewriter, curState);
21702182
appliedPatterns.erase(&pattern);
2171-
if (failed(result))
2172-
rewriterImpl.resetState(curState);
2183+
if (failed(result)) {
2184+
if (!rewriterImpl.config.allowPatternRollback)
2185+
op->emitError("pattern '")
2186+
<< pattern.getDebugName()
2187+
<< "' produced IR that could not be legalized";
2188+
rewriterImpl.resetState(curState, pattern.getDebugName());
2189+
}
21732190
if (config.listener)
21742191
config.listener->notifyPatternEnd(pattern, result);
21752192
return result;
@@ -2674,9 +2691,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
26742691
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
26752692
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
26762693

2677-
for (auto *op : toConvert)
2678-
if (failed(convert(rewriter, op)))
2679-
return rewriterImpl.undoRewrites(), failure();
2694+
for (auto *op : toConvert) {
2695+
if (failed(convert(rewriter, op))) {
2696+
// Dialect conversion failed.
2697+
if (rewriterImpl.config.allowPatternRollback) {
2698+
// Rollback is allowed: restore the original IR.
2699+
rewriterImpl.undoRewrites();
2700+
} else {
2701+
// Rollback is not allowed: apply all modifications that have been
2702+
// performed so far.
2703+
rewriterImpl.applyRewrites();
2704+
}
2705+
return failure();
2706+
}
2707+
}
26802708

26812709
// After a successful conversion, apply rewrites.
26822710
rewriterImpl.applyRewrites();

0 commit comments

Comments
 (0)