Skip to content

[mlir][Transforms] Dialect conversion: Add flag to disable rollback #136490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2025

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Apr 20, 2025

This commit adds a new flag to ConversionConfig to disallow the rollback of IR modification. This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will remove the ability to roll back IR modifications from the conversion driver.

RFC: https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083/46

By default, this flag is set to "true". I.e., the rollback of IR modifications is allowed. When set to "false", the conversion driver will report a fatal LLVM error when an IR rollback is requested. The name of the rolled back pattern is included in the error message. Moreover, the original IR is no longer restored after a failed conversion.

Example:

within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: error: pattern '(anonymous namespace)::CmpFOpNanKernelPattern' produced IR that could not be legalized
  %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
       ^
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: note: see current operation: %1 = "arith.cmpf"(%arg0, %arg1) <{fastmath = #arith.fastmath<fast>, predicate = 7 : i64}> : (f32, f32) -> i1
pattern '(anonymous namespace)::CmpFOpNanKernelPattern' rollback of IR modifications requested
UNREACHABLE executed at llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1231!

The majority of patterns in MLIR have already been updated such that they do not trigger any rollbacks, but a few SPIRV patterns remain. More information in the RFC.

Depends on #136489.

@matthias-springer matthias-springer marked this pull request as ready for review April 20, 2025 14:49
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Apr 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 20, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new flag to ConversionConfig to disallow the rollback of IR modification. This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will remove the ability to roll back IR modifications from the conversion driver.

RFC: https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083

By default, this flag is set to "true". I.e., the rollback of IR modifications is allowed. When set to "false", the conversion driver will report a fatal LLVM error when an IR rollback is requested. The name of the rolled back pattern is included in the error message. Moreover, when the original IR is no longer restored after a failed conversion.

Example:

within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: error: pattern '(anonymous namespace)::CmpFOpNanKernelPattern' produced IR that could not be legalized
  %0 = arith.cmpf ord, %arg0, %arg1 fastmath&lt;fast&gt; : f32
       ^
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: note: see current operation: %1 = "arith.cmpf"(%arg0, %arg1) &lt;{fastmath = #arith.fastmath&lt;fast&gt;, predicate = 7 : i64}&gt; : (f32, f32) -&gt; i1
pattern '(anonymous namespace)::CmpFOpNanKernelPattern' rollback of IR modifications requested
UNREACHABLE executed at llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1231!

The majority of patterns in MLIR have already been updated such that they do not trigger any rollbacks, but a few SPIRV patterns remain. More information in the RFC.

Depends on #136489.


Full diff: https://github.com/llvm/llvm-project/pull/136490.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+20)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-14)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b6ab252456e70..b65b3ea971f91 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1219,6 +1219,26 @@ struct ConversionConfig {
   /// materializations and instead inserts "builtin.unrealized_conversion_cast"
   /// ops to ensure that the resulting IR is valid.
   bool buildMaterializations = true;
+
+  /// If set to "true", pattern rollback is allowed. The conversion driver
+  /// rolls back IR modifications in the following situations.
+  ///
+  /// 1. Pattern implementation returns "failure" after modifying IR.
+  /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
+  ///    and cannot be legalized by subsequent foldings / pattern applications.
+  ///
+  /// If set to "false", the conversion driver will produce an LLVM fatal error
+  /// instead of rolling back IR modifications. Moreover, in case of a failed
+  /// conversion, the original IR is not restored. The resulting IR may be a
+  /// mix of original and rewritten IR. (Same as a failed greedy pattern
+  /// rewrite.)
+  ///
+  /// Note: This flag was added in preparation of the One-Shot Dialect
+  /// Conversion refactoring, which will remove the ability to roll back IR
+  /// modifications from the conversion driver. Use this flag to ensure that
+  /// your patterns do not trigger any IR rollbacks. For details, see
+  /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
+  bool allowPatternRollback = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 63225c6bbee7c..7a9da13427f91 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// conversion process succeeds.
   void applyRewrites();
 
-  /// Reset the state of the rewriter to a previously saved point.
-  void resetState(RewriterState state);
+  /// Reset the state of the rewriter to a previously saved point. Optionally,
+  /// the name of the pattern that triggered the rollback can specified for
+  /// debugging purposes.
+  void resetState(RewriterState state, StringRef patternName = "");
 
   /// Append a rewrite. Rewrites are committed upon success and rolled back upon
   /// failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   }
 
   /// Undo the rewrites (motions, splits) one by one in reverse order until
-  /// "numRewritesToKeep" rewrites remains.
-  void undoRewrites(unsigned numRewritesToKeep = 0);
+  /// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
+  /// that triggered the rollback can specified for debugging purposes.
+  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
 
   /// Remap the given values to those with potentially different types. Returns
   /// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
-void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+void ConversionPatternRewriterImpl::resetState(RewriterState state,
+                                               StringRef patternName) {
   // Undo any rewrites.
-  undoRewrites(state.numRewrites);
+  undoRewrites(state.numRewrites, patternName);
 
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
@@ -1216,10 +1220,19 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     replacedOps.pop_back();
 }
 
-void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
+void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
+                                                 StringRef patternName) {
   for (auto &rewrite :
-       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
+       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
+    if (!config.allowPatternRollback &&
+        !isa<UnresolvedMaterializationRewrite>(rewrite)) {
+      // Unresolved materializations can always be rolled back (erased).
+      std::string errorMessage = "pattern '" + std::string(patternName) +
+                                 "' rollback of IR modifications requested";
+      llvm_unreachable(errorMessage.c_str());
+    }
     rewrite->rollback();
+  }
   rewrites.resize(numRewritesToKeep);
 }
 
@@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     });
     if (config.listener)
       config.listener->notifyPatternEnd(pattern, failure());
-    rewriterImpl.resetState(curState);
+    rewriterImpl.resetState(curState, pattern.getDebugName());
     appliedPatterns.erase(&pattern);
   };
 
@@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     auto result = legalizePatternResult(op, pattern, rewriter, curState);
     appliedPatterns.erase(&pattern);
-    if (failed(result))
-      rewriterImpl.resetState(curState);
+    if (failed(result)) {
+      if (!rewriterImpl.config.allowPatternRollback)
+        op->emitError("pattern '")
+            << pattern.getDebugName()
+            << "' produced IR that could not be legalized";
+      rewriterImpl.resetState(curState, pattern.getDebugName());
+    }
     if (config.listener)
       config.listener->notifyPatternEnd(pattern, result);
     return result;
@@ -2674,9 +2692,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
 
-  for (auto *op : toConvert)
-    if (failed(convert(rewriter, op)))
-      return rewriterImpl.undoRewrites(), failure();
+  for (auto *op : toConvert) {
+    if (failed(convert(rewriter, op))) {
+      // Dialect conversion failed.
+      if (rewriterImpl.config.allowPatternRollback) {
+        // Rollback is allowed: restore the original IR.
+        rewriterImpl.undoRewrites();
+      } else {
+        // Rollback is not allowed: apply all modifications that have been
+        // performed so far.
+        rewriterImpl.applyRewrites();
+      }
+      return failure();
+    }
+  }
 
   // After a successful conversion, apply rewrites.
   rewriterImpl.applyRewrites();

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this would also allow folks to prefetch the future state? E.g., used to test for failures when rollback no longer supported?

/// Reset the state of the rewriter to a previously saved point. Optionally,
/// the name of the pattern that triggered the rollback can specified for
/// debugging purposes.
void resetState(RewriterState state, StringRef patternName = "");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are patternName params intended to stay post?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This entire function will disappear when we delete the rollback mechanism.

@matthias-springer
Copy link
Member Author

So this would also allow folks to prefetch the future state? E.g., used to test for failures when rollback no longer supported?

Yes, that's correct. The simplest way to do that is to change the default value of the new allowPatternRollback flag to false in DialectConversion.h. Alternatively, the flag can also be set to false on a case-by-case basis in each pass.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conv_rollback_fold branch from 63e2888 to 8bf5ca0 Compare April 22, 2025 06:55
Base automatically changed from users/matthias-springer/dialect_conv_rollback_fold to main April 22, 2025 07:12
@matthias-springer matthias-springer force-pushed the users/matthias-springer/no_rollback_flag branch from da7b888 to 7b4cb9d Compare April 22, 2025 07:22
@matthias-springer matthias-springer force-pushed the users/matthias-springer/no_rollback_flag branch from 7b4cb9d to 65002df Compare April 22, 2025 07:28
@matthias-springer matthias-springer merged commit 8bc0d4d into main Apr 22, 2025
7 of 10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/no_rollback_flag branch April 22, 2025 07:45
makslevental added a commit to makslevental/triton that referenced this pull request Apr 23, 2025
let's start working to integrate 

llvm/llvm-project#136489

and 

llvm/llvm-project#136490

(maybe there will be no work 🤞
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Apr 23, 2025
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136490)

This commit adds a new flag to `ConversionConfig` to disallow the
rollback of IR modification. This commit is in preparation of the
One-Shot Dialect Conversion refactoring, which will remove the ability
to roll back IR modifications from the conversion driver.

RFC:
https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083/46

By default, this flag is set to "true". I.e., the rollback of IR
modifications is allowed. When set to "false", the conversion driver
will report a fatal LLVM error when an IR rollback is requested. The
name of the rolled back pattern is included in the error message.
Moreover, the original IR is no longer restored after a failed
conversion.

Example:
```
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: error: pattern '(anonymous namespace)::CmpFOpNanKernelPattern' produced IR that could not be legalized
  %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
       ^
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: note: see current operation: %1 = "arith.cmpf"(%arg0, %arg1) <{fastmath = #arith.fastmath<fast>, predicate = 7 : i64}> : (f32, f32) -> i1
pattern '(anonymous namespace)::CmpFOpNanKernelPattern' rollback of IR modifications requested
UNREACHABLE executed at llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1231!
```

The majority of patterns in MLIR have already been updated such that
they do not trigger any rollbacks, but a few SPIRV patterns remain. More
information in the RFC.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136490)

This commit adds a new flag to `ConversionConfig` to disallow the
rollback of IR modification. This commit is in preparation of the
One-Shot Dialect Conversion refactoring, which will remove the ability
to roll back IR modifications from the conversion driver.

RFC:
https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083/46

By default, this flag is set to "true". I.e., the rollback of IR
modifications is allowed. When set to "false", the conversion driver
will report a fatal LLVM error when an IR rollback is requested. The
name of the rolled back pattern is included in the error message.
Moreover, the original IR is no longer restored after a failed
conversion.

Example:
```
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: error: pattern '(anonymous namespace)::CmpFOpNanKernelPattern' produced IR that could not be legalized
  %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
       ^
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: note: see current operation: %1 = "arith.cmpf"(%arg0, %arg1) <{fastmath = #arith.fastmath<fast>, predicate = 7 : i64}> : (f32, f32) -> i1
pattern '(anonymous namespace)::CmpFOpNanKernelPattern' rollback of IR modifications requested
UNREACHABLE executed at llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1231!
```

The majority of patterns in MLIR have already been updated such that
they do not trigger any rollbacks, but a few SPIRV patterns remain. More
information in the RFC.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136490)

This commit adds a new flag to `ConversionConfig` to disallow the
rollback of IR modification. This commit is in preparation of the
One-Shot Dialect Conversion refactoring, which will remove the ability
to roll back IR modifications from the conversion driver.

RFC:
https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083/46

By default, this flag is set to "true". I.e., the rollback of IR
modifications is allowed. When set to "false", the conversion driver
will report a fatal LLVM error when an IR rollback is requested. The
name of the rolled back pattern is included in the error message.
Moreover, the original IR is no longer restored after a failed
conversion.

Example:
```
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: error: pattern '(anonymous namespace)::CmpFOpNanKernelPattern' produced IR that could not be legalized
  %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
       ^
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 offset :11:8: note: see current operation: %1 = "arith.cmpf"(%arg0, %arg1) <{fastmath = #arith.fastmath<fast>, predicate = 7 : i64}> : (f32, f32) -> i1
pattern '(anonymous namespace)::CmpFOpNanKernelPattern' rollback of IR modifications requested
UNREACHABLE executed at llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1231!
```

The majority of patterns in MLIR have already been updated such that
they do not trigger any rollbacks, but a few SPIRV patterns remain. More
information in the RFC.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants