Skip to content

Commit 726f54c

Browse files
committed
Reland: [MLIR][Transforms] Fix Mem2Reg removal order to respect dominance
This commit fixes a bug in the Mem2Reg operation erasure order. Replacing the use-def based topological order with a dominance-based weak order ensures that no operation is removed before all its uses have been replaced. The order relation uses the topological order of blocks and the block-internal operation order to produce a deterministic operation order. Additionally, the reliance on the `DenseMap` key order was eliminated by switching to a `MapVector`, which gives a deterministic iteration order. Example: ``` %ptr = alloca ... ... %val0 = load %ptr ... // LOAD0 store %val0 %ptr ... %val1 = load %ptr ... // LOAD1 ``` When promoting the slot backing %ptr, it can happen that LOAD0 is cleaned up before LOAD1. This results in all uses of LOAD0 being replaced by its reaching definition before LOAD1's result is replaced by LOAD0's result. The subsequent erasure of LOAD0 can thus not succeed, as it still has remaining uses.
1 parent 660a78f commit 726f54c

File tree

11 files changed

+78
-38
lines changed

11 files changed

+78
-38
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,6 @@ namespace detail {
380380
/// to the results of preceding blocks.
381381
void connectPHINodes(Region &region, const ModuleTranslation &state);
382382

383-
/// Get a topologically sorted list of blocks of the given region.
384-
SetVector<Block *> getTopologicallySortedBlocks(Region &region);
385-
386383
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
387384
/// This currently supports integer, floating point, splat and dense element
388385
/// attributes and combinations thereof. Also, an array attribute with two

mlir/include/mlir/Transforms/RegionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
8787
LogicalResult runRegionDCE(RewriterBase &rewriter,
8888
MutableArrayRef<Region> regions);
8989

90+
/// Get a topologically sorted list of blocks of the given region.
91+
SetVector<Block *> getTopologicallySortedBlocks(Region &region);
92+
9093
} // namespace mlir
9194

9295
#endif // MLIR_TRANSFORMS_REGIONUTILS_H_

mlir/lib/Target/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ add_mlir_translation_library(MLIRTargetLLVMIRExport
3939
MLIRLLVMDialect
4040
MLIRLLVMIRTransforms
4141
MLIRTranslateLib
42+
MLIRTransformUtils
4243
)
4344

4445
add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration

mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenACCToLLVMIRTranslation
1010
MLIROpenACCDialect
1111
MLIRSupport
1212
MLIRTargetLLVMIRExport
13+
MLIRTransformUtils
1314
)

mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Support/LLVM.h"
2020
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
2121
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
22+
#include "mlir/Transforms/RegionUtils.h"
2223

2324
#include "llvm/ADT/TypeSwitch.h"
2425
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -395,8 +396,7 @@ static LogicalResult convertDataOp(acc::DataOp &op,
395396
llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
396397
ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
397398

398-
SetVector<Block *> blocks =
399-
LLVM::detail::getTopologicallySortedBlocks(op.getRegion());
399+
SetVector<Block *> blocks = getTopologicallySortedBlocks(op.getRegion());
400400
for (Block *bb : blocks) {
401401
llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
402402
if (bb->isEntryBlock()) {

mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenMPToLLVMIRTranslation
1010
MLIROpenMPDialect
1111
MLIRSupport
1212
MLIRTargetLLVMIRExport
13+
MLIRTransformUtils
1314
)

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,7 @@ static llvm::BasicBlock *convertOmpOpRegions(
194194

195195
// Convert blocks one by one in topological order to ensure
196196
// defs are converted before uses.
197-
SetVector<Block *> blocks =
198-
LLVM::detail::getTopologicallySortedBlocks(region);
197+
SetVector<Block *> blocks = getTopologicallySortedBlocks(region);
199198
for (Block *bb : blocks) {
200199
llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
201200
// Retarget the branch of the entry block to the entry block of the

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "mlir/Support/LogicalResult.h"
3232
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
3333
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
34+
#include "mlir/Transforms/RegionUtils.h"
3435

3536
#include "llvm/ADT/PostOrderIterator.h"
3637
#include "llvm/ADT/SetVector.h"
@@ -571,24 +572,6 @@ void mlir::LLVM::detail::connectPHINodes(Region &region,
571572
}
572573
}
573574

574-
/// Sort function blocks topologically.
575-
SetVector<Block *>
576-
mlir::LLVM::detail::getTopologicallySortedBlocks(Region &region) {
577-
// For each block that has not been visited yet (i.e. that has no
578-
// predecessors), add it to the list as well as its successors.
579-
SetVector<Block *> blocks;
580-
for (Block &b : region) {
581-
if (blocks.count(&b) == 0) {
582-
llvm::ReversePostOrderTraversal<Block *> traversal(&b);
583-
blocks.insert(traversal.begin(), traversal.end());
584-
}
585-
}
586-
assert(blocks.size() == region.getBlocks().size() &&
587-
"some blocks are not sorted");
588-
589-
return blocks;
590-
}
591-
592575
llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
593576
llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
594577
ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
@@ -922,7 +905,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
922905

923906
// Then, convert blocks one by one in topological order to ensure defs are
924907
// converted before uses.
925-
auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
908+
auto blocks = getTopologicallySortedBlocks(func.getBody());
926909
for (Block *bb : blocks) {
927910
llvm::IRBuilder<> builder(llvmContext);
928911
if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1717
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1818
#include "mlir/Transforms/Passes.h"
19+
#include "mlir/Transforms/RegionUtils.h"
20+
#include "llvm/ADT/PostOrderIterator.h"
1921
#include "llvm/ADT/STLExtras.h"
2022
#include "llvm/Support/Casting.h"
2123
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -96,6 +98,9 @@ using namespace mlir;
9698

9799
namespace {
98100

101+
using BlockingUsesMap =
102+
llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
103+
99104
/// Information computed during promotion analysis used to perform actual
100105
/// promotion.
101106
struct MemorySlotPromotionInfo {
@@ -106,7 +111,7 @@ struct MemorySlotPromotionInfo {
106111
/// its uses, it is because the defining ops of the blocking uses requested
107112
/// it. The defining ops therefore must also have blocking uses or be the
108113
/// starting point of the blocking uses.
109-
DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
114+
BlockingUsesMap userToBlockingUses;
110115
};
111116

112117
/// Computes information for basic slot promotion. This will check that direct
@@ -129,8 +134,7 @@ class MemorySlotPromotionAnalyzer {
129134
/// uses (typically, removing its users because it will delete itself to
130135
/// resolve its own blocking uses). This will fail if one of the transitive
131136
/// users cannot remove a requested use, and should prevent promotion.
132-
LogicalResult computeBlockingUses(
133-
DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
137+
LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
134138

135139
/// Computes in which blocks the value stored in the slot is actually used,
136140
/// meaning blocks leading to a load. This method uses `definingBlocks`, the
@@ -233,7 +237,7 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
233237
}
234238

235239
LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
236-
DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses) {
240+
BlockingUsesMap &userToBlockingUses) {
237241
// The promotion of an operation may require the promotion of further
238242
// operations (typically, removing operations that use an operation that must
239243
// delete itself). We thus need to start from the use of the slot pointer and
@@ -243,7 +247,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
243247
// use it.
244248
for (OpOperand &use : slot.ptr.getUses()) {
245249
SmallPtrSet<OpOperand *, 4> &blockingUses =
246-
userToBlockingUses.getOrInsertDefault(use.getOwner());
250+
userToBlockingUses[use.getOwner()];
247251
blockingUses.insert(&use);
248252
}
249253

@@ -281,7 +285,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
281285
assert(llvm::is_contained(user->getResults(), blockingUse->get()));
282286

283287
SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
284-
userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
288+
userToBlockingUses[blockingUse->getOwner()];
285289
newUserBlockingUseSet.insert(blockingUse);
286290
}
287291
}
@@ -515,15 +519,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
515519
}
516520
}
517521

522+
/// Sorts `ops` according to dominance. Relies on the topological order of basic
523+
/// blocks to get a deterministic ordering.
524+
static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
525+
// Produce a topological block order and construct a map to lookup the indices
526+
// of blocks.
527+
DenseMap<Block *, size_t> topoBlockIndices;
528+
SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
529+
for (auto [index, block] : llvm::enumerate(topologicalOrder))
530+
topoBlockIndices[block] = index;
531+
532+
// Combining the topological order of the basic blocks together with block
533+
// internal operation order guarantees a deterministic, dominance respecting
534+
// order.
535+
llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
536+
size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
537+
size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
538+
if (lhsBlockIndex == rhsBlockIndex)
539+
return lhs->isBeforeInBlock(rhs);
540+
return lhsBlockIndex < rhsBlockIndex;
541+
});
542+
}
543+
518544
void MemorySlotPromoter::removeBlockingUses() {
519-
llvm::SetVector<Operation *> usersToRemoveUses;
520-
for (auto &user : llvm::make_first_range(info.userToBlockingUses))
521-
usersToRemoveUses.insert(user);
522-
SetVector<Operation *> sortedUsersToRemoveUses =
523-
mlir::topologicalSort(usersToRemoveUses);
545+
llvm::SmallVector<Operation *> usersToRemoveUses(
546+
llvm::make_first_range(info.userToBlockingUses));
547+
548+
// Sort according to dominance.
549+
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
524550

525551
llvm::SmallVector<Operation *> toErase;
526-
for (Operation *toPromote : llvm::reverse(sortedUsersToRemoveUses)) {
552+
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
527553
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
528554
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
529555
// If no reaching definition is known, this use is outside the reach of

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,3 +836,19 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
836836
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
837837
mergedIdenticalBlocks);
838838
}
839+
840+
SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
841+
// For each block that has not been visited yet (i.e. that has no
842+
// predecessors), add it to the list as well as its successors.
843+
SetVector<Block *> blocks;
844+
for (Block &b : region) {
845+
if (blocks.count(&b) == 0) {
846+
llvm::ReversePostOrderTraversal<Block *> traversal(&b);
847+
blocks.insert(traversal.begin(), traversal.end());
848+
}
849+
}
850+
assert(blocks.size() == region.getBlocks().size() &&
851+
"some blocks are not sorted");
852+
853+
return blocks;
854+
}

mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,3 +683,16 @@ llvm.func @no_inner_alloca_promotion(%arg: i64) -> i64 {
683683
// CHECK: llvm.return %[[RES]] : i64
684684
llvm.return %2 : i64
685685
}
686+
687+
// -----
688+
689+
// CHECK-LABEL: @transitive_reaching_def
690+
llvm.func @transitive_reaching_def() -> !llvm.ptr {
691+
%0 = llvm.mlir.constant(1 : i32) : i32
692+
// CHECK-NOT: alloca
693+
%1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
694+
%2 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
695+
llvm.store %2, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
696+
%3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
697+
llvm.return %3 : !llvm.ptr
698+
}

0 commit comments

Comments
 (0)