@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
861
861
// / conversion process succeeds.
862
862
void applyRewrites ();
863
863
864
- // / Reset the state of the rewriter to a previously saved point.
865
- void resetState (RewriterState state);
864
+ // / Reset the state of the rewriter to a previously saved point. Optionally,
865
+ // / the name of the pattern that triggered the rollback can specified for
866
+ // / debugging purposes.
867
+ void resetState (RewriterState state, StringRef patternName = " " );
866
868
867
869
// / Append a rewrite. Rewrites are committed upon success and rolled back upon
868
870
// / failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
873
875
}
874
876
875
877
// / Undo the rewrites (motions, splits) one by one in reverse order until
876
- // / "numRewritesToKeep" rewrites remains.
877
- void undoRewrites (unsigned numRewritesToKeep = 0 );
878
+ // / "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
879
+ // / that triggered the rollback can specified for debugging purposes.
880
+ void undoRewrites (unsigned numRewritesToKeep = 0 , StringRef patternName = " " );
878
881
879
882
// / Remap the given values to those with potentially different types. Returns
880
883
// / success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1204
1207
return RewriterState (rewrites.size (), ignoredOps.size (), replacedOps.size ());
1205
1208
}
1206
1209
1207
- void ConversionPatternRewriterImpl::resetState (RewriterState state) {
1210
+ void ConversionPatternRewriterImpl::resetState (RewriterState state,
1211
+ StringRef patternName) {
1208
1212
// Undo any rewrites.
1209
- undoRewrites (state.numRewrites );
1213
+ undoRewrites (state.numRewrites , patternName );
1210
1214
1211
1215
// Pop all of the recorded ignored operations that are no longer valid.
1212
1216
while (ignoredOps.size () != state.numIgnoredOperations )
@@ -1216,10 +1220,19 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1216
1220
replacedOps.pop_back ();
1217
1221
}
1218
1222
1219
- void ConversionPatternRewriterImpl::undoRewrites (unsigned numRewritesToKeep) {
1223
+ void ConversionPatternRewriterImpl::undoRewrites (unsigned numRewritesToKeep,
1224
+ StringRef patternName) {
1220
1225
for (auto &rewrite :
1221
- llvm::reverse (llvm::drop_begin (rewrites, numRewritesToKeep)))
1226
+ llvm::reverse (llvm::drop_begin (rewrites, numRewritesToKeep))) {
1227
+ if (!config.allowPatternRollback &&
1228
+ !isa<UnresolvedMaterializationRewrite>(rewrite)) {
1229
+ // Unresolved materializations can always be rolled back (erased).
1230
+ std::string errorMessage = " pattern '" + std::string (patternName) +
1231
+ " ' rollback of IR modifications requested" ;
1232
+ llvm_unreachable (errorMessage.c_str ());
1233
+ }
1222
1234
rewrite->rollback ();
1235
+ }
1223
1236
rewrites.resize (numRewritesToKeep);
1224
1237
}
1225
1238
@@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2158
2171
});
2159
2172
if (config.listener )
2160
2173
config.listener ->notifyPatternEnd (pattern, failure ());
2161
- rewriterImpl.resetState (curState);
2174
+ rewriterImpl.resetState (curState, pattern. getDebugName () );
2162
2175
appliedPatterns.erase (&pattern);
2163
2176
};
2164
2177
@@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2168
2181
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2169
2182
auto result = legalizePatternResult (op, pattern, rewriter, curState);
2170
2183
appliedPatterns.erase (&pattern);
2171
- if (failed (result))
2172
- rewriterImpl.resetState (curState);
2184
+ if (failed (result)) {
2185
+ if (!rewriterImpl.config .allowPatternRollback )
2186
+ op->emitError (" pattern '" )
2187
+ << pattern.getDebugName ()
2188
+ << " ' produced IR that could not be legalized" ;
2189
+ rewriterImpl.resetState (curState, pattern.getDebugName ());
2190
+ }
2173
2191
if (config.listener )
2174
2192
config.listener ->notifyPatternEnd (pattern, result);
2175
2193
return result;
@@ -2674,9 +2692,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2674
2692
ConversionPatternRewriter rewriter (ops.front ()->getContext (), config);
2675
2693
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2676
2694
2677
- for (auto *op : toConvert)
2678
- if (failed (convert (rewriter, op)))
2679
- return rewriterImpl.undoRewrites (), failure ();
2695
+ for (auto *op : toConvert) {
2696
+ if (failed (convert (rewriter, op))) {
2697
+ // Dialect conversion failed.
2698
+ if (rewriterImpl.config .allowPatternRollback ) {
2699
+ // Rollback is allowed: restore the original IR.
2700
+ rewriterImpl.undoRewrites ();
2701
+ } else {
2702
+ // Rollback is not allowed: apply all modifications that have been
2703
+ // performed so far.
2704
+ rewriterImpl.applyRewrites ();
2705
+ }
2706
+ return failure ();
2707
+ }
2708
+ }
2680
2709
2681
2710
// After a successful conversion, apply rewrites.
2682
2711
rewriterImpl.applyRewrites ();
0 commit comments