[MLIR] Removing dead values for branches #117501

Merged: 13 commits merged into llvm:main on Dec 5, 2024

Conversation

parsifal-47
Contributor

Fixing RemoveDeadValues to properly remove arguments from BranchOpInterface operations.
This is a follow-up for: #117405
cc: @joker-eph @codemzs

@llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir labels on Nov 24, 2024
@llvmbot
Member

llvmbot commented Nov 24, 2024

@llvm/pr-subscribers-mlir-core

Author: Renat Idrisov (parsifal-47)

Changes

Fixing RemoveDeadValues to properly remove arguments from BranchOpInterface operations.
This is a follow-up for: #117405
cc: @joker-eph @codemzs


Full diff: https://github.com/llvm/llvm-project/pull/117501.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+40-15)
  • (modified) mlir/test/Transforms/remove-dead-values.mlir (+20-3)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0aa9dcb36681b3..638726e1212772 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -563,6 +563,44 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
 }
 
+// 1. Iterate over each successor block of the given BranchOpInterface
+// operation.
+// 2. For each successor block:
+//    a. Retrieve the operands passed to the successor.
+//    b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine
+//    which
+//       operands are live in the successor block.
+//    c. Mark each operand as live or dead based on the analysis.
+// 3. Remove dead operands from the branch operation and arguments accordingly
+
+static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
+  unsigned numSuccessors = branchOp->getNumSuccessors();
+
+  // Do (1)
+  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
+    Block *successorBlock = branchOp->getSuccessor(succIdx);
+
+    // Do (2)
+    SuccessorOperands successorOperands =
+        branchOp.getSuccessorOperands(succIdx);
+    SmallVector<Value> operandValues;
+    for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
+         ++operandIdx) {
+      operandValues.push_back(successorOperands[operandIdx]);
+    }
+
+    BitVector successorLiveOperands = markLives(operandValues, la);
+
+    // Do (3)
+    for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
+      if (!successorLiveOperands[argIdx]) {
+        successorOperands.erase(argIdx);
+        successorBlock->eraseArgument(argIdx);
+      }
+    }
+  }
+}
+
 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
   void runOnOperation() override;
 };
@@ -572,26 +610,13 @@ void RemoveDeadValues::runOnOperation() {
   auto &la = getAnalysis<RunLivenessAnalysis>();
   Operation *module = getOperation();
 
-  // The removal of non-live values is performed iff there are no branch ops,
-  // and all symbol user ops present in the IR are call-like.
-  WalkResult acceptableIR = module->walk([&](Operation *op) {
-    if (op == module)
-      return WalkResult::advance();
-    if (isa<BranchOpInterface>(op)) {
-      op->emitError() << "cannot optimize an IR with branch ops\n";
-      return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
-  });
-
-  if (acceptableIR.wasInterrupted())
-    return signalPassFailure();
-
   module->walk([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
       cleanFuncOp(funcOp, module, la);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
       cleanRegionBranchOp(regionBranchOp, la);
+    } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+      cleanBranchOp(branchOp, la);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 826f6159a36b67..fda7ef3fe673e4 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -28,15 +28,32 @@ module @named_module_acceptable {
 
 // -----
 
-// The IR remains untouched because of the presence of a branch op `cf.cond_br`.
+// The IR is optimized regardless of the presence of a branch op `cf.cond_br`.
 //
-func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
+func.func @acceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
   %non_live = arith.constant 0 : i32
-  // expected-error @+1 {{cannot optimize an IR with branch ops}}
+  // CHECK-NOT: non_live
   cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
 ^bb1(%non_live_0 : i32):
+  // CHECK-NOT: non_live_0
   cf.br ^bb3
 ^bb2(%non_live_1 : i32):
+  // CHECK-NOT: non_live_1
+  cf.br ^bb3
+^bb3:
+  return
+}
+
+// -----
+
+// Arguments of unconditional branch op `cf.br` are properly removed.
+//
+func.func @acceptable_ir_has_cleanable_simple_op_with_unconditional_branch_op(%arg0: i1) {
+  %non_live = arith.constant 0 : i32
+  // CHECK-NOT: non_live
+  cf.br ^bb1(%non_live : i32)
+^bb1(%non_live_1 : i32):
+  // CHECK-NOT: non_live_1
   cf.br ^bb3
 ^bb3:
   return
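
For readers who want to see the index handling of step (3) in isolation, here is a minimal standalone sketch. It is not MLIR code: `ToyEdge` and `eraseDeadPositions` are hypothetical names, and plain `std::vector` stands in for the successor operands and the successor block arguments. It only illustrates why the loop runs from the highest index down: erasing position `i` shifts every later entry, so a forward loop would address the wrong elements.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy stand-in (not an MLIR type): a branch forwards `operands`
// to a successor block that declares one argument per forwarded operand.
struct ToyEdge {
  std::vector<int> operands;  // values forwarded by the branch
  std::vector<int> blockArgs; // arguments declared by the successor block
};

// Erase every position whose liveness bit is false from both lists.
// Iterating from the last index to the first keeps the remaining (lower)
// indices valid after each erase.
inline void eraseDeadPositions(ToyEdge &edge, const std::vector<bool> &live) {
  assert(edge.operands.size() == live.size());
  assert(edge.blockArgs.size() == live.size());
  for (std::size_t i = live.size(); i-- > 0;) {
    if (live[i])
      continue;
    const auto offset = static_cast<std::ptrdiff_t>(i);
    edge.operands.erase(edge.operands.begin() + offset);
    edge.blockArgs.erase(edge.blockArgs.begin() + offset);
  }
}
```

Calling `eraseDeadPositions` on an edge with four operands and the liveness bits `{true, false, true, false}` leaves the first and third entries of both lists, still paired position by position.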

@parsifal-47 changed the title from "Removing dead values for branches" to "[MLIR] Removing dead values for branches" on Nov 24, 2024
Collaborator

@joker-eph joker-eph left a comment

Can you try to write a test with a scf.for loop that takes some iter_args?

@parsifal-47
Contributor Author

Can you try to write a test with a scf.for loop that takes some iter_args?

Sure, updated with the test: 1663984
@joker-eph please take a look, thank you!

%non_live = arith.constant 0 : i32
// CHECK-NOT: non_live
cf.br ^bb1(%non_live : i32)
^bb1(%non_live_1 : i32):
Contributor

Can we create a cf.br op inside this branch and then pass the %non_live_1 value as one of its successor block arguments? We are recursively recreating this scenario in one or both of its successor blocks, and I just want to see if alias analysis will clean up the values that refer to the same dead value.

Contributor

It would also be nice to have a case where one of these dead values is used in the conditions of cf.br of a similar example.

Contributor Author

Added a sub-branch with a non_live argument, which led me to discover and fix a bug in the implementation, thank you!
For the second request, I am not sure I understand: if a value is used as a condition for control flow, it is live, so it can't be dead.

Contributor Author

@codemzs updated, let me know what you think

Contributor

In ^bb1(%non_live_1 : i32): can you create something like %non_live_2 = arith.constant 0 : i32 and then pass it as a second argument to cf.br ^bb3(%non_live_1 : i32, %non_live_2 : i32)

Contributor Author

done, please take a look, thank you!

%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%non_live = arith.constant 0 : index
// CHECK-NOT: non_live
Collaborator

The names are never propagated, so I don't quite see how these checks could test anything.

I believe you need to actually check that there is a single iter_args.

Contributor Author

thank you for noticing!
Changed the condition to // CHECK: scf.for %[[ARG_0:.*]] = %c0 to %c10 step %c1 iter_args(%[[ARG_1:.*]] = %arg0)

Contributor Author

@joker-eph updated, please let me know if I understood your comment correctly

Collaborator

I still see a // CHECK-NOT: non_live line 53 right now?

Contributor Author

Sorry, I thought it was limited to scf. I removed all CHECK-NOT: <name> conditions and replaced them with positive checks, thank you!

// c. Mark each operand as live or dead based on the analysis.
// 3. Remove dead operands from the branch operation and arguments accordingly

static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
Contributor

@codemzs codemzs Nov 27, 2024

I think this needs to be a recursive solution, because you can have the same situation in a conditional branch. For example, within a branch you could declare variables and then pass them to a nested conditional branch; in that case they won't be part of your initial successor block arguments.

Contributor Author

My first idea was to make a recursive function, but I checked other pieces of this pass and saw no recursion. It looks like it is applied iteratively until a stable point is reached and no more values are deleted. If we do it recursively, we need to check for maximal depth and cyclic dependencies, and do that for all kinds of operations. Do you have any advice, @joker-eph?

Collaborator

This function is called in a walk() which itself will recursively visit the IR.

That said: what about adding a test to cover the case that @codemzs is describing?

Contributor

Thanks @joker-eph, I completely forgot this was being invoked from walk(). We are good on that front.

Contributor Author

I added such a test, but after looking closely I realized that walk and walk<WalkOrder::PostOrder> do not help with BranchOp, probably because branch ops do not have a parent-child relation. Sorry, I was wrong: the pass does not reapply iteratively by itself. I am going to add recursion in some form and let you know.

Contributor

I am quite surprised walk() does not traverse nested branch ops. I would also try creating this IR programmatically to ensure you are not missing anything.

Contributor Author

It does; the issue is the ordering. I need "inner" branches to be traversed first, which is what PostOrder does, but in the case of branches there is not much of a hierarchy. I am going to play with it a bit and get back to you.

Contributor Author

It turns out to be simpler. I added a test with a branching loop that passes multiple dead values around to both conditional and unconditional branches. walk works perfectly when all of them are cleaned consistently. Thank you!
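
To make the "cleaned consistently" point concrete, here is a small standalone sketch. It is a toy model under stated assumptions, not the pass itself: `ToyBranch`, `ToyBlock`, `ToyCfg`, `DeadArgs`, and `sweepToyCfg` are hypothetical names, and the per-block dead-argument bits are assumed to have been computed up front (the role the liveness analysis plays in the PR). Under that assumption a single sweep is enough, even for a block that branches back to itself, because every branch and every block consults the same precomputed answer and the visiting order does not matter.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy IR: a branch names its target block and forwards one
// operand per argument of that target.
struct ToyBranch {
  std::string target;
  std::vector<std::string> operands;
};
struct ToyBlock {
  std::vector<std::string> args;
  std::vector<ToyBranch> branches;
};
using ToyCfg = std::map<std::string, ToyBlock>;
// For each block name: one bit per block argument, true meaning "dead".
using DeadArgs = std::map<std::string, std::vector<bool>>;

// Erase the marked positions, highest index first so indices stay valid.
inline void eraseMarked(std::vector<std::string> &items,
                        const std::vector<bool> &dead) {
  assert(items.size() == dead.size());
  for (std::size_t i = dead.size(); i-- > 0;)
    if (dead[i])
      items.erase(items.begin() + static_cast<std::ptrdiff_t>(i));
}

// One sweep: first every branch drops the operands feeding dead arguments of
// its target, then every block drops its dead arguments exactly once.
inline void sweepToyCfg(ToyCfg &cfg, const DeadArgs &dead) {
  for (auto &entry : cfg)
    for (ToyBranch &branch : entry.second.branches)
      eraseMarked(branch.operands, dead.at(branch.target));
  for (auto &entry : cfg)
    eraseMarked(entry.second.args, dead.at(entry.first));
}
```

Separating operand removal from argument removal is a choice made for this sketch: it keeps a block that has several predecessors from having the same argument erased more than once.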

@parsifal-47
Contributor Author

@codemzs @joker-eph your comments should be resolved now, thank you!

@parsifal-47
Contributor Author

@codemzs @joker-eph please take a look once you have a chance, I think I addressed all of your comments, thank you!

@CoTinker
Contributor

CoTinker commented Dec 4, 2024

And please update the comments of cleanSimpleOp

/// It is assumed that `op` is simple. Here, a simple op is one which isn't a
/// symbol op, a symbol-user op, a region branch op, a branch op, a region
/// branch terminator op, or return-like.
static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {

@parsifal-47
Contributor Author

And please update the comments of cleanSimpleOp

cleanSimpleOp has not changed; it does not invoke anything additional now. Or maybe I do not understand your comment.

@CoTinker
Contributor

CoTinker commented Dec 4, 2024

Sorry, I misunderstood it. But maybe we should change the part of the comment that says "symbol op, a symbol-user op", because symbol ops other than function-like ops and symbol-user ops other than call-like ops are now received by this function.

@parsifal-47
Contributor Author

a symbol-user op

no problem, sounds good, updated the comment, please take a look, thank you!

@CoTinker
Contributor

CoTinker commented Dec 4, 2024

Sorry, I should have been a little more explicit.
Please change the original comments from symbol op, a symbol-user op to function-like op, a call-like op.

^bb3(%non_live_2 : i32, %non_live_6 : i32):
// CHECK: ^[[BB3]]:
cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
// CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
Contributor

My understanding is that all of the arith.constant values in your tests are not live, hence none of these branch blocks should have any arguments. But why do we see ^[[BB4:bb[0-9]+]] versus ^[[BB1]]?

Contributor Author

Correct. The difference between ^[[BB4:bb[0-9]+]] and ^[[BB1]] is that the first one matches and captures the name that was generated by the optimization, while the second one reuses a name captured earlier, if I understand the question correctly.

cf.br ^bb3
^bb3:
// CHECK-NOT: arith.constant
cf.br ^bb1(%non_live : i32)
Contributor

Let's add a check for cf.br, i.e. what is expected after this line.

Contributor Author

updated, thank you!

@parsifal-47
Contributor Author

Sorry, I should have been a little more explicit. Please change the original comments from symbol op, a symbol-user op to function-like op, a call-like op.

sorry about that, updated, thank you!

Contributor

@codemzs codemzs left a comment

Thank you, @parsifal-47 for making this change!

@CoTinker
Contributor

CoTinker commented Dec 5, 2024

Thanks for your work, LGTM. If you have time, you can take a look at issue #118450.
It's a crash caused by calling dropAllUses before erase, which results in a linalg op with empty operands.

@parsifal-47
Contributor Author

Thanks for your work, LGTM. If you have time, you can take a look at issue #118450. It's a crash caused by calling dropAllUses before erase, which results in a linalg op with empty operands.

Sure, I will be happy to take a look!
Meanwhile, do you have write access here to merge this PR?
Thank you!

@CoTinker CoTinker merged commit 0629e9e into llvm:main Dec 5, 2024
8 checks passed