Fix block merging #96871

Merged 4 commits on Jul 2, 2024
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
// We don't want the block structure to change and invalidate the
// `BufferOriginAnalysis`, so we apply the rewrites with a `Normal` level of
// region simplification.
GreedyRewriteConfig config;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());

if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config)))
signalPassFailure();
}
};
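
Editorial note: the comment in the hunk above is the key design point — block merging must not run while `BufferOriginAnalysis` results are live. Below is a minimal, self-contained sketch of that driver configuration; the helper name `applyWithoutBlockMerging` is invented, while `GreedyRewriteConfig`, `GreedySimplifyRegionLevel`, and `applyPatternsAndFoldGreedily` are the APIs already used in the diff.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Apply a pattern set greedily while capping region simplification at
// `Normal`, so the driver does not merge blocks underneath an analysis that
// caches `Block *` pointers (e.g. `BufferOriginAnalysis`).
static LogicalResult applyWithoutBlockMerging(Operation *root,
                                              RewritePatternSet &&patterns) {
  GreedyRewriteConfig config;
  config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
  return applyPatternsAndFoldGreedily(root, std::move(patterns), config);
}
```

The `Normal` level still runs unreachable-block erasure and region DCE; it only skips the block merging extended later in this PR.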
144 changes: 132 additions & 12 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,18 +9,23 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"

#include <deque>
#include <iterator>

using namespace mlir;

@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());

// Update each of the predecessor terminators with the new arguments.
SmallVector<SmallVector<Value, 8>, 2> newArguments(
1 + blocksToMerge.size(),
SmallVector<Value, 8>(operandsToMerge.size()));
SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
SmallVector<Value, 8>());
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
newArguments[i][it.index()] = operand.get();

// Update the operand and insert an argument if this is the leader.
if (i == 0) {
Value operandVal = operand.get();
operand.set(leaderBlock->addArgument(operandVal.getType(),
operandVal.getLoc()));
Value operandVal = operand.get();
Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
operandVal);
if (it == newArguments[i].end()) {
newArguments[i].push_back(operandVal);
// Update the operand and insert an argument if this is the leader.
if (i == 0) {
operand.set(leaderBlock->addArgument(operandVal.getType(),
operandVal.getLoc()));
}
} else if (i == 0) {
// If this is the leader, update the operand but do not insert a new
// argument. Instead, the operand should point to the argument we
// already added for `operandVal`.
operand.set(leaderBlock->getArgument(
std::distance(newArguments[i].begin(), it)));
}
}
}
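
Editorial note: the essence of the change above is that equal forwarded values now share a single block argument instead of one argument per operand position. Below is a standalone sketch of that bookkeeping, detached from MLIR types; `dedupForwardedOperands` and `ValueT` are made-up names, and in the real code `mlir::Value` plays the role of `ValueT`.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#include <cstddef>
#include <iterator>
#include <utility>

// For each original operand position, record which deduplicated argument it
// should forward to; equal values collapse onto one argument.
template <typename ValueT>
std::pair<llvm::SmallVector<ValueT>, llvm::SmallVector<size_t>>
dedupForwardedOperands(llvm::ArrayRef<ValueT> forwarded) {
  llvm::SmallVector<ValueT> args;
  llvm::SmallVector<size_t> argIndexForOperand;
  for (const ValueT &value : forwarded) {
    ValueT *it = llvm::find(args, value);
    if (it == args.end()) {
      // First time we see this value: it becomes a new argument.
      argIndexForOperand.push_back(args.size());
      args.push_back(value);
    } else {
      // Reuse the argument that already carries this value.
      argIndexForOperand.push_back(std::distance(args.begin(), it));
    }
  }
  return {std::move(args), std::move(argIndexForOperand)};
}
```

In `BlockMergeCluster::merge` above, the same find-or-append pattern decides whether the leader block needs a fresh argument or can point the operand at one it already added.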
@@ -818,6 +831,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}

static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
Block &block) {
SmallVector<size_t> argsToErase;

// Go through the arguments of the block
for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
bool sameArg = true;
Value commonValue;

// Go through the block predecessor and flag if they pass to the block
// different values for the same argument
for (auto predIt = block.pred_begin(), predE = block.pred_end();
predIt != predE; ++predIt) {
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
if (!branch) {
sameArg = false;
break;
}
unsigned succIndex = predIt.getSuccessorIndex();
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
auto operands = succOperands.getForwardedOperands();
if (!commonValue) {
commonValue = operands[argIdx];
} else {
if (operands[argIdx] != commonValue) {
sameArg = false;
break;
}
}
}

// If they are passing the same value, drop the argument
if (commonValue && sameArg) {
argsToErase.push_back(argIdx);

// Remove the argument from the block
Value argVal = block.getArgument(argIdx);
rewriter.replaceAllUsesWith(argVal, commonValue);
}
}

// Remove the arguments
for (auto argIdx : llvm::reverse(argsToErase)) {
block.eraseArgument(argIdx);

// Remove the argument from the branch ops
for (auto predIt = block.pred_begin(), predE = block.pred_end();
predIt != predE; ++predIt) {
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
unsigned succIndex = predIt.getSuccessorIndex();
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
succOperands.erase(argIdx);
}
}
return success(!argsToErase.empty());
}

/// This optimization drops redundant arguments from blocks. I.e., if a given
/// argument to a block receives the same value from each of the block's
/// predecessors, we can remove the argument from the block and use the
/// original value directly. This is a simple example:
///
/// %cond = llvm.call @rand() : () -> i1
/// %val0 = llvm.mlir.constant(1 : i64) : i64
/// %val1 = llvm.mlir.constant(2 : i64) : i64
/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64),
///                     ^bb2(%val0 : i64, %val2 : i64)
///
/// ^bb1(%arg0 : i64, %arg1 : i64):
/// llvm.call @foo(%arg0, %arg1)
///
/// The previous IR can be rewritten as:
/// %cond = llvm.call @rand() : () -> i1
/// %val0 = llvm.mlir.constant(1 : i64) : i64
/// %val1 = llvm.mlir.constant(2 : i64) : i64
/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
///
/// ^bb1(%arg0 : i64):
/// llvm.call @foo(%val0, %arg0)
///
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
llvm::SmallSetVector<Region *, 1> worklist;
for (auto &region : regions)
worklist.insert(&region);
bool anyChanged = false;
while (!worklist.empty()) {
Region *region = worklist.pop_back_val();

// Add any nested regions to the worklist.
for (Block &block : *region) {
anyChanged |= succeeded(dropRedundantArguments(rewriter, block));

for (auto &op : block)
for (auto &nestedRegion : op.getRegions())
worklist.insert(&nestedRegion);
}
}
return success(anyChanged);
}

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
@@ -832,8 +948,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
if (mergeBlocks)
bool droppedRedundantArguments = false;
if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
droppedRedundantArguments =
succeeded(dropRedundantArguments(rewriter, regions));
}
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
mergedIdenticalBlocks || droppedRedundantArguments);
}
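
Editorial note: with this hunk, `simplifyRegions` reports a change when any of the four sub-steps fired, and redundant-argument dropping piggybacks on the existing `mergeBlocks` flag. A rough usage sketch follows; the wrapper name `cleanUpRegions` is invented, everything else is the API shown in this file.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"

using namespace mlir;

// Run the full region simplification (unreachable-block erasure, region DCE,
// identical-block merging, and now redundant-argument dropping) on an op's
// regions. `simplifyRegions` returns success iff something changed.
static bool cleanUpRegions(Operation *op) {
  IRRewriter rewriter(op->getContext());
  return succeeded(
      simplifyRegions(rewriter, op->getRegions(), /*mergeBlocks=*/true));
}
```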
@@ -178,28 +178,32 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: ^bb1
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
// CHECK-NEXT: test.buffer_based
// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
// CHECK-NEXT: ^bb3:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
// CHECK-NEXT: ^bb4:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: test.copy
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
// CHECK: cf.br ^{{.*}}
// CHECK: ^{{.*}}:
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>
67 changes: 29 additions & 38 deletions mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
}

// CHECK-LABEL: func @main()
// CHECK-DAG: arith.constant 0
// CHECK-DAG: arith.constant 10
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
// CHECK-NEXT: return %{{.*}}
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
// CHECK-DAG: arith.constant true
// CHECK: cf.br
// CHECK-NEXT: ^[[bb1:.*]]:
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
// CHECK-NEXT: ^[[bb2]]
// CHECK-NEXT: cf.br ^[[bb3:.*]]
// CHECK-NEXT: ^[[bb3]]
// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }

// -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
}

// CHECK-LABEL: func @main()
// CHECK-DAG: arith.constant 0
// CHECK-DAG: arith.constant 10
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
// CHECK-NEXT: cf.br ^[[bb4:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
// CHECK-NEXT: return %{{.*}}
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
// CHECK-DAG: arith.constant true
// CHECK: cf.br ^[[bb1:.*]]
// CHECK-NEXT: ^[[bb1:.*]]:
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
// CHECK-NEXT: ^[[bb2]]:
// CHECK-NEXT: cf.br ^[[bb3:.*]]
// CHECK-NEXT: ^[[bb3]]:
// CHECK-NEXT: cf.br ^[[bb4:.*]]
// CHECK-NEXT: ^[[bb4]]:
// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }

// -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
}

// CHECK-LABEL: func @main()
// CHECK-DAG: arith.constant 0
// CHECK-DAG: arith.constant 10
// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
// CHECK-NEXT: return %{{.*}}
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10>
// CHECK-DAG: arith.constant true
// CHECK: cf.br ^[[bb1:.*]]
// CHECK-NEXT: ^[[bb1]]:
// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
// CHECK-NEXT: ^[[bb2]]
// CHECK-NEXT: cf.br ^[[bb3:.*]]
// CHECK-NEXT: ^[[bb3]]
// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
12 changes: 6 additions & 6 deletions mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb1]](%{{.*}}: i32)
// DET-ALL: arith.cmpi slt, {{.*}}
// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb2]](%{{.*}}: i32)
// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
// DET-ALL: ^[[bb2]]
// DET-ALL: arith.addi {{.*}}
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : i32)
// DET-ALL: ^[[bb3]](%{{.*}}: i32)
// DET-ALL: ^[[bb3]]:
// DET-ALL: tensor.from_elements {{.*}}
// DET-ALL: return %{{.*}} : tensor<i32>

@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-CF: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb1]](%{{.*}}: i32)
// DET-CF: arith.cmpi slt, {{.*}}
// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb2]](%{{.*}}: i32)
// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
// DET-CF: ^[[bb2]]:
// DET-CF: arith.addi {{.*}}
// DET-CF: cf.br ^[[bb1]](%{{.*}} : i32)
// DET-CF: ^[[bb3]](%{{.*}}: i32)
// DET-CF: ^[[bb3]]:
// DET-CF: tensor.from_elements %{{.*}} : tensor<i32>
// DET-CF: return %{{.*}} : tensor<i32>