Skip to content

Commit 0629e9e

Browse files
authored
[MLIR] Removing dead values for branches (#117501)
Fixing RemoveDeadValues to properly remove arguments from BranchOpInterface operations. This is a follow-up for: #117405 cc: @joker-eph @codemzs --------- Co-authored-by: Renat Idrisov <[email protected]>
1 parent abc2703 commit 0629e9e

File tree

2 files changed

+86
-25
lines changed

2 files changed

+86
-25
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
172172
/// iff it has no memory effects and none of its results are live.
173173
///
174174
/// It is assumed that `op` is simple. Here, a simple op is one which isn't a
175-
/// symbol op, a symbol-user op, a region branch op, a branch op, a region
175+
/// function-like op, a call-like op, a region branch op, a branch op, a region
176176
/// branch terminator op, or return-like.
177177
static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
178178
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
@@ -563,6 +563,51 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
563563
dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
564564
}
565565

566+
// 1. Iterate over each successor block of the given BranchOpInterface
567+
// operation.
568+
// 2. For each successor block:
569+
// a. Retrieve the operands passed to the successor.
570+
// b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine
571+
// which operands are live in the successor block.
572+
// c. Mark each operand as live or dead based on the analysis.
573+
// 3. Remove dead operands from the branch operation and arguments accordingly
574+
575+
static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
576+
unsigned numSuccessors = branchOp->getNumSuccessors();
577+
578+
// Do (1)
579+
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
580+
Block *successorBlock = branchOp->getSuccessor(succIdx);
581+
582+
// Do (2)
583+
SuccessorOperands successorOperands =
584+
branchOp.getSuccessorOperands(succIdx);
585+
SmallVector<Value> operandValues;
586+
for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
587+
++operandIdx) {
588+
operandValues.push_back(successorOperands[operandIdx]);
589+
}
590+
591+
BitVector successorLiveOperands = markLives(operandValues, la);
592+
593+
// Do (3)
594+
for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
595+
if (!successorLiveOperands[argIdx]) {
596+
if (successorBlock->getNumArguments() < successorOperands.size()) {
597+
// if block was cleaned through a different code path
598+
// we only need to remove operands from the invokation
599+
successorOperands.erase(argIdx);
600+
continue;
601+
}
602+
603+
successorBlock->getArgument(argIdx).dropAllUses();
604+
successorOperands.erase(argIdx);
605+
successorBlock->eraseArgument(argIdx);
606+
}
607+
}
608+
}
609+
}
610+
566611
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
567612
void runOnOperation() override;
568613
};
@@ -572,26 +617,13 @@ void RemoveDeadValues::runOnOperation() {
572617
auto &la = getAnalysis<RunLivenessAnalysis>();
573618
Operation *module = getOperation();
574619

575-
// The removal of non-live values is performed iff there are no branch ops,
576-
// and all symbol user ops present in the IR are call-like.
577-
WalkResult acceptableIR = module->walk([&](Operation *op) {
578-
if (op == module)
579-
return WalkResult::advance();
580-
if (isa<BranchOpInterface>(op)) {
581-
op->emitError() << "cannot optimize an IR with branch ops\n";
582-
return WalkResult::interrupt();
583-
}
584-
return WalkResult::advance();
585-
});
586-
587-
if (acceptableIR.wasInterrupted())
588-
return signalPassFailure();
589-
590620
module->walk([&](Operation *op) {
591621
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
592622
cleanFuncOp(funcOp, module, la);
593623
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
594624
cleanRegionBranchOp(regionBranchOp, la);
625+
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
626+
cleanBranchOp(branchOp, la);
595627
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
596628
// Nothing to do here because this is a terminator op and it should be
597629
// honored with respect to its parent

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,51 @@ module @named_module_acceptable {
2828

2929
// -----
3030

31-
// The IR remains untouched because of the presence of a branch op `cf.cond_br`.
31+
// The IR contains both conditional and unconditional branches with a loop
32+
// in which the last cf.cond_br is referncing the first cf.br
3233
//
33-
func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
34+
func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) {
3435
%non_live = arith.constant 0 : i32
35-
// expected-error @+1 {{cannot optimize an IR with branch ops}}
36-
cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
37-
^bb1(%non_live_0 : i32):
38-
cf.br ^bb3
39-
^bb2(%non_live_1 : i32):
40-
cf.br ^bb3
41-
^bb3:
36+
// CHECK-NOT: arith.constant
37+
cf.br ^bb1(%non_live : i32)
38+
// CHECK: cf.br ^[[BB1:bb[0-9]+]]
39+
^bb1(%non_live_1 : i32):
40+
// CHECK: ^[[BB1]]:
41+
%non_live_5 = arith.constant 1 : i32
42+
cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
43+
// CHECK: cf.br ^[[BB3:bb[0-9]+]]
44+
// CHECK-NOT: i32
45+
^bb3(%non_live_2 : i32, %non_live_6 : i32):
46+
// CHECK: ^[[BB3]]:
47+
cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
48+
// CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
49+
^bb4(%non_live_4 : i32):
50+
// CHECK: ^[[BB4]]:
4251
return
4352
}
4453

4554
// -----
4655

56+
// Checking that iter_args are properly handled
57+
//
58+
func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
59+
%c0 = arith.constant 0 : index
60+
%c1 = arith.constant 1 : index
61+
%c10 = arith.constant 10 : index
62+
%non_live = arith.constant 0 : index
63+
// CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
64+
%result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
65+
// CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
66+
%new_live = arith.addi %live_arg, %i : index
67+
// CHECK: scf.yield [[SUM:%.+]]
68+
scf.yield %new_live, %non_live_arg : index, index
69+
}
70+
// CHECK: return [[RESULT]] : index
71+
return %result : index
72+
}
73+
74+
// -----
75+
4776
// Note that this cleanup cannot be done by the `canonicalize` pass.
4877
//
4978
// CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {

0 commit comments

Comments
 (0)