Skip to content

Commit be81f42

Browse files
authored
[MLIR][Transforms] Fix Mem2Reg removal order to respect dominance (#68687)
This commit fixes a bug in the Mem2Reg operation erasure order. Replacing the topological order with a dominance-based order ensures that no operation is removed before all of its uses have been replaced. Additionally, the reliance on the `DenseMap` key order was eliminated by switching to a `MapVector`, which gives a deterministic iteration order.

Example:

```
%ptr = alloca ...
...
%val0 = load %ptr ... // LOAD0
store %val0 %ptr ...
%val1 = load %ptr ... // LOAD1
```

When promoting the slot backing %ptr, it can happen that LOAD0 is cleaned up before LOAD1. This results in all uses of LOAD0 being replaced by its reaching definition before LOAD1's result is replaced by LOAD0's result. The subsequent erasure of LOAD0 then cannot succeed, because it still has remaining uses.
1 parent 0c21dfd commit be81f42

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ using namespace mlir;
9696

9797
namespace {
9898

99+
using BlockingUsesMap =
100+
llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
101+
99102
/// Information computed during promotion analysis used to perform actual
100103
/// promotion.
101104
struct MemorySlotPromotionInfo {
@@ -106,7 +109,7 @@ struct MemorySlotPromotionInfo {
106109
/// its uses, it is because the defining ops of the blocking uses requested
107110
/// it. The defining ops therefore must also have blocking uses or be the
108111
/// starting point of the blocking uses.
109-
DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
112+
BlockingUsesMap userToBlockingUses;
110113
};
111114

112115
/// Computes information for basic slot promotion. This will check that direct
@@ -129,8 +132,7 @@ class MemorySlotPromotionAnalyzer {
129132
/// uses (typically, removing its users because it will delete itself to
130133
/// resolve its own blocking uses). This will fail if one of the transitive
131134
/// users cannot remove a requested use, and should prevent promotion.
132-
LogicalResult computeBlockingUses(
133-
DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
135+
LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
134136

135137
/// Computes in which blocks the value stored in the slot is actually used,
136138
/// meaning blocks leading to a load. This method uses `definingBlocks`, the
@@ -233,7 +235,7 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
233235
}
234236

235237
LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
236-
DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses) {
238+
BlockingUsesMap &userToBlockingUses) {
237239
// The promotion of an operation may require the promotion of further
238240
// operations (typically, removing operations that use an operation that must
239241
// delete itself). We thus need to start from the use of the slot pointer and
@@ -243,7 +245,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
243245
// use it.
244246
for (OpOperand &use : slot.ptr.getUses()) {
245247
SmallPtrSet<OpOperand *, 4> &blockingUses =
246-
userToBlockingUses.getOrInsertDefault(use.getOwner());
248+
userToBlockingUses[use.getOwner()];
247249
blockingUses.insert(&use);
248250
}
249251

@@ -281,7 +283,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
281283
assert(llvm::is_contained(user->getResults(), blockingUse->get()));
282284

283285
SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
284-
userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
286+
userToBlockingUses[blockingUse->getOwner()];
285287
newUserBlockingUseSet.insert(blockingUse);
286288
}
287289
}
@@ -516,14 +518,16 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
516518
}
517519

518520
void MemorySlotPromoter::removeBlockingUses() {
519-
llvm::SetVector<Operation *> usersToRemoveUses;
520-
for (auto &user : llvm::make_first_range(info.userToBlockingUses))
521-
usersToRemoveUses.insert(user);
522-
SetVector<Operation *> sortedUsersToRemoveUses =
523-
mlir::topologicalSort(usersToRemoveUses);
521+
llvm::SmallVector<Operation *> usersToRemoveUses(
522+
llvm::make_first_range(info.userToBlockingUses));
523+
// The uses need to be traversed in *reverse dominance* order to ensure that
524+
// transitive replacements are performed correctly.
525+
llvm::sort(usersToRemoveUses, [&](Operation *lhs, Operation *rhs) {
526+
return dominance.properlyDominates(rhs, lhs);
527+
});
524528

525529
llvm::SmallVector<Operation *> toErase;
526-
for (Operation *toPromote : llvm::reverse(sortedUsersToRemoveUses)) {
530+
for (Operation *toPromote : usersToRemoveUses) {
527531
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
528532
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
529533
// If no reaching definition is known, this use is outside the reach of

mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,3 +683,16 @@ llvm.func @no_inner_alloca_promotion(%arg: i64) -> i64 {
683683
// CHECK: llvm.return %[[RES]] : i64
684684
llvm.return %2 : i64
685685
}
686+
687+
// -----
688+
689+
// CHECK-LABEL: @transitive_reaching_def
690+
llvm.func @transitive_reaching_def() -> !llvm.ptr {
691+
%0 = llvm.mlir.constant(1 : i32) : i32
692+
// CHECK-NOT: alloca
693+
%1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
694+
%2 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
695+
llvm.store %2, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
696+
%3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
697+
llvm.return %3 : !llvm.ptr
698+
}

0 commit comments

Comments
 (0)