Skip to content

[MLIR] Removing dead values for branches #117501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 5, 2024
Merged
64 changes: 48 additions & 16 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// iff it has no memory effects and none of its results are live.
///
/// It is assumed that `op` is simple. Here, a simple op is one which isn't a
/// symbol op, a symbol-user op, a region branch op, a branch op, a region
/// function-like op, a call-like op, a region branch op, a branch op, a region
/// branch terminator op, or return-like.
static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
Expand Down Expand Up @@ -563,6 +563,51 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
}

// 1. Iterate over each successor block of the given BranchOpInterface
// operation.
// 2. For each successor block:
// a. Retrieve the operands passed to the successor.
// b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine
// which operands are live in the successor block.
// c. Mark each operand as live or dead based on the analysis.
// 3. Remove dead operands from the branch operation and arguments accordingly

static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
Copy link
Contributor

@codemzs codemzs Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be a recursive solution, you can have the same situation in a conditional branch, for example with in a branch you could declare variables and then pass them to the nested conditional branch in this case it won't be part of your initial successor block arguments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My first idea was to make a recursive function, but I checked other pieces of this pass and I saw no recursion. It looks like it is applied till stable point is reached and no more values getting deleted. In iterative fashion. If doing recursively, we need to check for maximal depth and cyclic dependencies. And do that for any kinds of operations. Do you have an advice @joker-eph ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is called in a walk() which itself will recursively visit the IR.

That said: what about adding a test to cover the case that @codemzs is describing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @joker-eph completely forgot this was being invoked from walk(). We are good on that front.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added such test, but after looking closely I realized that walk and walk<WalkOrder::PostOrder> does not help with BranchOp, probably because they do not have parent-child relation. Sorry, I was wrong, it does not reapply iteratively by itself. I am going to add a recursion in some form and let you know.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am quite surprised walk() does not traverse nested branch ops, I would also try creating this IR programmatically to ensure you are not missing anything.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it does, the issue is the ordering, I need "inner" branches to be traversed first, that is what PostOrder does, but in case of Branches there is not so much hierarchy. I am going to play with it a bit and get back to you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out to be simpler, I added a test with branching loop passing multiple dead values around to both conditional and unconditional branches. walk works perfectly when all of them are cleaned consistently. Thank you!

unsigned numSuccessors = branchOp->getNumSuccessors();

// Do (1)
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
Block *successorBlock = branchOp->getSuccessor(succIdx);

// Do (2)
SuccessorOperands successorOperands =
branchOp.getSuccessorOperands(succIdx);
SmallVector<Value> operandValues;
for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
++operandIdx) {
operandValues.push_back(successorOperands[operandIdx]);
}

BitVector successorLiveOperands = markLives(operandValues, la);

// Do (3)
for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
if (!successorLiveOperands[argIdx]) {
if (successorBlock->getNumArguments() < successorOperands.size()) {
// if block was cleaned through a different code path
// we only need to remove operands from the invokation
successorOperands.erase(argIdx);
continue;
}

successorBlock->getArgument(argIdx).dropAllUses();
successorOperands.erase(argIdx);
successorBlock->eraseArgument(argIdx);
}
}
}
}

struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
void runOnOperation() override;
};
Expand All @@ -572,26 +617,13 @@ void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();

// The removal of non-live values is performed iff there are no branch ops,
// and all symbol user ops present in the IR are call-like.
WalkResult acceptableIR = module->walk([&](Operation *op) {
if (op == module)
return WalkResult::advance();
if (isa<BranchOpInterface>(op)) {
op->emitError() << "cannot optimize an IR with branch ops\n";
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (acceptableIR.wasInterrupted())
return signalPassFailure();

module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
cleanFuncOp(funcOp, module, la);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
cleanRegionBranchOp(regionBranchOp, la);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
cleanBranchOp(branchOp, la);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
Expand Down
47 changes: 38 additions & 9 deletions mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,51 @@ module @named_module_acceptable {

// -----

// The IR remains untouched because of the presence of a branch op `cf.cond_br`.
// The IR contains both conditional and unconditional branches with a loop
// in which the last cf.cond_br is referncing the first cf.br
//
func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) {
%non_live = arith.constant 0 : i32
// expected-error @+1 {{cannot optimize an IR with branch ops}}
cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
^bb1(%non_live_0 : i32):
cf.br ^bb3
^bb2(%non_live_1 : i32):
cf.br ^bb3
^bb3:
// CHECK-NOT: arith.constant
cf.br ^bb1(%non_live : i32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add a check for cf.br i.e what is expected after this line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated, thank you!

// CHECK: cf.br ^[[BB1:bb[0-9]+]]
^bb1(%non_live_1 : i32):
// CHECK: ^[[BB1]]:
%non_live_5 = arith.constant 1 : i32
cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
// CHECK: cf.br ^[[BB3:bb[0-9]+]]
// CHECK-NOT: i32
^bb3(%non_live_2 : i32, %non_live_6 : i32):
// CHECK: ^[[BB3]]:
cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
// CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is all of the arith.constant in your tests are not alive, hence all of these branch blocks should not have any arguments but why do we see ^[[BB4:bb[0-9]+]] versus ^[[BB1]]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, the difference between ^[[BB4:bb[0-9]+]] and ^[[BB1]] that the first one is matching the branch name which was generated by optimization, and the second one is using the name. If I understand the question correctly.

^bb4(%non_live_4 : i32):
// CHECK: ^[[BB4]]:
return
}

// -----

// Checking that iter_args are properly handled
//
func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%non_live = arith.constant 0 : index
// CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
%result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
// CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
%new_live = arith.addi %live_arg, %i : index
// CHECK: scf.yield [[SUM:%.+]]
scf.yield %new_live, %non_live_arg : index, index
}
// CHECK: return [[RESULT]] : index
return %result : index
}

// -----

// Note that this cleanup cannot be done by the `canonicalize` pass.
//
// CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {
Expand Down
Loading