Skip to content

Reland: [MLIR][Transforms] Fix Mem2Reg removal order to respect dominance #68877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 12, 2023

Conversation

Dinistro
Copy link
Contributor

This commit fixes a bug in the Mem2Reg operation erasure order. Replacing the use-def based topological order with a dominance-based weak order ensures that no operation is removed before all its uses have been replaced. The order relation uses the topological order of blocks and block internal ordering to determine a deterministic operation order.

Additionally, the reliance on the DenseMap key order was eliminated by switching to a MapVector, that gives a deterministic iteration order.

Example:

%ptr = alloca ...
...
%val0 = %load %ptr ... // LOAD0
store %val0 %ptr ...
%val1 = load %ptr ... // LOAD1

When promoting the slot backing %ptr, it can happen that the LOAD0 was cleaned before LOAD1. This results in all uses of LOAD0 being replaced by its reaching definition, before LOAD1's result is replaced by LOAD0's result. The subsequent erasure of LOAD0 can thus not succeed, as it has remaining usages.

@llvmbot
Copy link
Member

llvmbot commented Oct 12, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir-openacc

Author: Christian Ulmann (Dinistro)

Changes

This commit fixes a bug in the Mem2Reg operation erasure order. Replacing the use-def based topological order with a dominance-based weak order ensures that no operation is removed before all its uses have been replaced. The order relation uses the topological order of blocks and block internal ordering to determine a deterministic operation order.

Additionally, the reliance on the DenseMap key order was eliminated by switching to a MapVector, that gives a deterministic iteration order.

Example:

%ptr = alloca ...
...
%val0 = %load %ptr ... // LOAD0
store %val0 %ptr ...
%val1 = load %ptr ... // LOAD1

When promoting the slot backing %ptr, it can happen that the LOAD0 was cleaned before LOAD1. This results in all uses of LOAD0 being replaced by its reaching definition, before LOAD1's result is replaced by LOAD0's result. The subsequent erasure of LOAD0 can thus not succeed, as it has remaining usages.


Full diff: https://github.com/llvm/llvm-project/pull/68877.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (-3)
  • (modified) mlir/include/mlir/Transforms/RegionUtils.h (+3)
  • (modified) mlir/lib/Target/LLVMIR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt (+1)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp (+2-2)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt (+1)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+1-2)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+2-19)
  • (modified) mlir/lib/Transforms/Mem2Reg.cpp (+38-12)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+16)
  • (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+13)
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 68b522405535a69..f9026f84935be52 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -380,9 +380,6 @@ namespace detail {
 /// to the results of preceding blocks.
 void connectPHINodes(Region &region, const ModuleTranslation &state);
 
-/// Get a topologically sorted list of blocks of the given region.
-SetVector<Block *> getTopologicallySortedBlocks(Region &region);
-
 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
 /// This currently supports integer, floating point, splat and dense element
 /// attributes and combinations thereof. Also, an array attribute with two
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 06eebff201d1b4d..192ff7138405973 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -87,6 +87,9 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
 LogicalResult runRegionDCE(RewriterBase &rewriter,
                            MutableArrayRef<Region> regions);
 
+/// Get a topologically sorted list of blocks of the given region.
+SetVector<Block *> getTopologicallySortedBlocks(Region &region);
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_REGIONUTILS_H_
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 868ccbbb10620d9..5db0885d70d6e7a 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -39,6 +39,7 @@ add_mlir_translation_library(MLIRTargetLLVMIRExport
   MLIRLLVMDialect
   MLIRLLVMIRTransforms
   MLIRTranslateLib
+  MLIRTransformUtils
   )
 
 add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
index 5e79e4be58f2965..e43581e28c77afa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenACCToLLVMIRTranslation
   MLIROpenACCDialect
   MLIRSupport
   MLIRTargetLLVMIRExport
+  MLIRTransformUtils
   )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index 392d34cd6f91353..37fec190d6f401d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -395,8 +396,7 @@ static LogicalResult convertDataOp(acc::DataOp &op,
   llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
       ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
 
-  SetVector<Block *> blocks =
-      LLVM::detail::getTopologicallySortedBlocks(op.getRegion());
+  SetVector<Block *> blocks = getTopologicallySortedBlocks(op.getRegion());
   for (Block *bb : blocks) {
     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
     if (bb->isEntryBlock()) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
index 41227f80cdfc8dd..0a5d7c6e22058d4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenMPToLLVMIRTranslation
   MLIROpenMPDialect
   MLIRSupport
   MLIRTargetLLVMIRExport
+  MLIRTransformUtils
   )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1ec3bb8e7562a9e..208c3d690e5532c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -194,8 +194,7 @@ static llvm::BasicBlock *convertOmpOpRegions(
 
   // Convert blocks one by one in topological order to ensure
   // defs are converted before uses.
-  SetVector<Block *> blocks =
-      LLVM::detail::getTopologicallySortedBlocks(region);
+  SetVector<Block *> blocks = getTopologicallySortedBlocks(region);
   for (Block *bb : blocks) {
     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
     // Retarget the branch of the entry block to the entry block of the
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ee73b04e020fd26..7312388bc9b4dd2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
@@ -571,24 +572,6 @@ void mlir::LLVM::detail::connectPHINodes(Region &region,
   }
 }
 
-/// Sort function blocks topologically.
-SetVector<Block *>
-mlir::LLVM::detail::getTopologicallySortedBlocks(Region &region) {
-  // For each block that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list as well as its successors.
-  SetVector<Block *> blocks;
-  for (Block &b : region) {
-    if (blocks.count(&b) == 0) {
-      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
-      blocks.insert(traversal.begin(), traversal.end());
-    }
-  }
-  assert(blocks.size() == region.getBlocks().size() &&
-         "some blocks are not sorted");
-
-  return blocks;
-}
-
 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
     llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
     ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
@@ -922,7 +905,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Then, convert blocks one by one in topological order to ensure defs are
   // converted before uses.
-  auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
+  auto blocks = getTopologicallySortedBlocks(func.getBody());
   for (Block *bb : blocks) {
     llvm::IRBuilder<> builder(llvmContext);
     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 65de25dd2f32663..1794d12ae2768db 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -96,6 +98,9 @@ using namespace mlir;
 
 namespace {
 
+using BlockingUsesMap =
+    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
+
 /// Information computed during promotion analysis used to perform actual
 /// promotion.
 struct MemorySlotPromotionInfo {
@@ -106,7 +111,7 @@ struct MemorySlotPromotionInfo {
   /// its uses, it is because the defining ops of the blocking uses requested
   /// it. The defining ops therefore must also have blocking uses or be the
   /// starting point of the bloccking uses.
-  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+  BlockingUsesMap userToBlockingUses;
 };
 
 /// Computes information for basic slot promotion. This will check that direct
@@ -129,8 +134,7 @@ class MemorySlotPromotionAnalyzer {
   /// uses (typically, removing its users because it will delete itself to
   /// resolve its own blocking uses). This will fail if one of the transitive
   /// users cannot remove a requested use, and should prevent promotion.
-  LogicalResult computeBlockingUses(
-      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
 
   /// Computes in which blocks the value stored in the slot is actually used,
   /// meaning blocks leading to a load. This method uses `definingBlocks`, the
@@ -233,7 +237,7 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
 }
 
 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
-    DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses) {
+    BlockingUsesMap &userToBlockingUses) {
   // The promotion of an operation may require the promotion of further
   // operations (typically, removing operations that use an operation that must
   // delete itself). We thus need to start from the use of the slot pointer and
@@ -243,7 +247,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
   // use it.
   for (OpOperand &use : slot.ptr.getUses()) {
     SmallPtrSet<OpOperand *, 4> &blockingUses =
-        userToBlockingUses.getOrInsertDefault(use.getOwner());
+        userToBlockingUses[use.getOwner()];
     blockingUses.insert(&use);
   }
 
@@ -281,7 +285,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
       assert(llvm::is_contained(user->getResults(), blockingUse->get()));
 
       SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
-          userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
+          userToBlockingUses[blockingUse->getOwner()];
       newUserBlockingUseSet.insert(blockingUse);
     }
   }
@@ -515,15 +519,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
   }
 }
 
+/// Sorts `ops` according to dominance. Relies on the topological order of basic
+/// blocks to get a deterministic ordering.
+static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
+  // Produce a topological block order and construct a map to lookup the indices
+  // of blocks.
+  DenseMap<Block *, size_t> topoBlockIndices;
+  SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
+  for (auto [index, block] : llvm::enumerate(topologicalOrder))
+    topoBlockIndices[block] = index;
+
+  // Combining the topological order of the basic blocks together with block
+  // internal operation order guarentees a deterministic, dominance respecting
+  // order.
+  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
+    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
+    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
+    if (lhsBlockIndex == rhsBlockIndex)
+      return lhs->isBeforeInBlock(rhs);
+    return lhsBlockIndex < rhsBlockIndex;
+  });
+}
+
 void MemorySlotPromoter::removeBlockingUses() {
-  llvm::SetVector<Operation *> usersToRemoveUses;
-  for (auto &user : llvm::make_first_range(info.userToBlockingUses))
-    usersToRemoveUses.insert(user);
-  SetVector<Operation *> sortedUsersToRemoveUses =
-      mlir::topologicalSort(usersToRemoveUses);
+  llvm::SmallVector<Operation *> usersToRemoveUses(
+      llvm::make_first_range(info.userToBlockingUses));
+
+  // Sort according to dominance.
+  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
 
   llvm::SmallVector<Operation *> toErase;
-  for (Operation *toPromote : llvm::reverse(sortedUsersToRemoveUses)) {
+  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
     if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
       Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
       // If no reaching definition is known, this use is outside the reach of
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b95af9ca0299649..1f2344677e6515c 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -836,3 +836,19 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
+
+SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
+  // For each block that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list as well as its successors.
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.count(&b) == 0) {
+      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+      blocks.insert(traversal.begin(), traversal.end());
+    }
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 30ba459d07a49f3..32e3fed7e5485df 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -683,3 +683,16 @@ llvm.func @no_inner_alloca_promotion(%arg: i64) -> i64 {
   // CHECK: llvm.return %[[RES]] : i64
   llvm.return %2 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @transitive_reaching_def
+llvm.func @transitive_reaching_def() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: alloca
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.store %2, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.return %3 : !llvm.ptr
+}

@Dinistro
Copy link
Contributor Author

This is the third attempt of fixing this problem, previous commits did not have a correct strict weak ordering comparison function.
Link to the previous PR: #68767

Copy link
Contributor

@definelicht definelicht left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

…ance

This commit fixes a bug in the Mem2Reg operation erasure order.
Replacing the use-def based topological order with a dominance-based
weak order ensures that no operation is removed before all its uses
have been replaced. The order relation uses the topological order of
blocks and block internal ordering to determine a deterministic
operation order.

Additionally, the reliance on the `DenseMap` key order was eliminated
by switching to a `MapVector`, that gives a deterministic iteration
order.

Example:

```
%ptr = alloca ...
...
%val0 = %load %ptr ... // LOAD0
store %val0 %ptr ...
%val1 = load %ptr ... // LOAD1
````

When promoting the slot backing %ptr, it can happen that the LOAD0 was
cleaned before LOAD1. This results in all uses of LOAD0 being replaced
by its reaching definition, before LOAD1's result is replaced by
LOAD0's result. The subsequent erasure of LOAD0 can thus not succeed,
as it has remaining usages.
@Dinistro Dinistro force-pushed the fix-mem2reg-removal-order branch from 0b33cf6 to 726f54c Compare October 12, 2023 11:42
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Dinistro Dinistro merged commit ab6a66d into llvm:main Oct 12, 2023
@Dinistro Dinistro deleted the fix-mem2reg-removal-order branch October 12, 2023 14:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants