Allow SymbolUserOpInterface operators to be used in RemoveDeadValues Pass #117405

Merged: 4 commits, Nov 23, 2024

Conversation

Contributor

@codemzs codemzs commented Nov 23, 2024

This change removes the restriction in the RemoveDeadValues pass on non-call SymbolUserOpInterface operators, so the pass accepts IR in which they are used with operators that implement SymbolOpInterface. For example:

memref.global implements SymbolOpInterface, so it can be used with memref.get_global, which implements SymbolUserOpInterface:

// Define a global constant array
memref.global "private" constant @global_array : memref<10xi32> = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi32>

// Access this global constant within a function
func.func @use_global() {
  %0 = memref.get_global @global_array : memref<10xi32>
  // Use %0 as needed
  return
}

Reference: #116519 and https://discourse.llvm.org/t/question-on-criteria-for-acceptable-ir-in-removedeadvaluespass/83131

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 23, 2024
Member

llvmbot commented Nov 23, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: M. Zeeshan Siddiqui (codemzs)

Changes

Pursuant to the conversation at #116519, this change removes the restriction in the RemoveDeadValues pass on non-call SymbolUserOpInterface operators, so the pass accepts IR in which they are used with operators that implement SymbolOpInterface. For example:

memref.global implements SymbolOpInterface, so it can be used with memref.get_global, which implements SymbolUserOpInterface:

// Define a global constant array
memref.global "private" constant @global_array : memref<10xi32> = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi32>

// Access this global constant within a function
func.func @use_global() {
  %0 = memref.get_global @global_array : memref<10xi32>
  // Use %0 as needed
  return
}

Full diff: https://github.com/llvm/llvm-project/pull/117405.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+2-4)
  • (modified) mlir/test/Transforms/remove-dead-values.mlir (+2-1)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index b82280dda8ba73..0aa9dcb36681b3 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -577,10 +577,8 @@ void RemoveDeadValues::runOnOperation() {
   WalkResult acceptableIR = module->walk([&](Operation *op) {
     if (op == module)
       return WalkResult::advance();
-    if (isa<BranchOpInterface>(op) ||
-        (isa<SymbolUserOpInterface>(op) && !isa<CallOpInterface>(op))) {
-      op->emitError() << "cannot optimize an IR with "
-                         "non-call symbol user ops or branch ops\n";
+    if (isa<BranchOpInterface>(op)) {
+      op->emitError() << "cannot optimize an IR with branch ops\n";
       return WalkResult::interrupt();
     }
     return WalkResult::advance();
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 47137fc6430fea..7a8d49681a4b18 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -3,9 +3,10 @@
 // The IR is updated regardless of memref.global private constant
 //
 module {
-  memref.global "private" constant @__something_global : memref<i32> = dense<0>
+  memref.global "private" constant @global_buffer : memref<5xi32> = dense<[1, 2, 3, 4, 5]> : tensor<5xi32>
   func.func @main(%arg0: i32) -> i32 {
     %0 = tensor.empty() : tensor<10xbf16>
+    %1 = memref.get_global @global_buffer : memref<5xi32>
     // CHECK-NOT: tensor.empty
     return %arg0 : i32
   }
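
The effect of the change to the acceptability check can be modeled outside MLIR. The following is a minimal sketch in plain C++, not the pass itself: `OpTraits` and its boolean fields are hypothetical stand-ins for the `isa<...>` queries in the diff, and `rejectedBefore` / `rejectedAfter` are illustrative names that do not exist in the code base.

```cpp
#include <cassert>

// Hypothetical stand-in for an operation. Each flag mimics the result of
// one `isa<...>` query from the diff; none of this is MLIR API.
struct OpTraits {
  bool isBranchOp = false;     // models isa<BranchOpInterface>(op)
  bool isSymbolUserOp = false; // models isa<SymbolUserOpInterface>(op)
  bool isCallOp = false;       // models isa<CallOpInterface>(op)
};

// The check before this change: branch ops and non-call symbol users
// both make the pass give up on the whole module.
bool rejectedBefore(const OpTraits &op) {
  return op.isBranchOp || (op.isSymbolUserOp && !op.isCallOp);
}

// The check after this change: only branch ops are rejected.
bool rejectedAfter(const OpTraits &op) { return op.isBranchOp; }
```

Under this model, an op shaped like memref.get_global (a symbol user that is not a call) is rejected by the old predicate and accepted by the new one, while a branch op is rejected by both.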

Collaborator

Do we know why the pass is incompatible with the branch op interface?

Contributor Author

@joker-eph I am not familiar enough with the pass to be sure, and I am confused myself: the pass seems to work on branch ops and even has tests that check for this condition and error message. However, I believe it does so only if the IR does not have "any non-function symbol ops, non-call symbol user ops and branch ops".

Contributor

I just checked the details. The pass uses LivenessAnalysis:
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp#L51
which is designed so that it does not miss live values. Its definition explicitly accounts for:

  • Control flow in branches: it considers non-forwarded branch operands and whether they lead to blocks with memory effects (1.b).
  • Transitive dependencies: it ensures that all values contributing to the computation of live values (those with memory effects or returned by public functions) are also marked as live (3.a, 3.b).

Therefore, I recommend dropping this condition as well and adding or modifying the test.
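
The transitive-dependency rule described in this comment can be illustrated with a small backward fixpoint. This is only a sketch under simplifying assumptions, not the actual LivenessAnalysis: values are plain integers, `ToyOp` and `computeLive` are hypothetical names, and control flow and regions are ignored entirely.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical toy operation: values are identified by integers.
struct ToyOp {
  std::vector<int> operands;
  std::vector<int> results;
  bool hasMemoryEffects = false;
};

// Seeds liveness with the returned values, then repeatedly marks the
// operands of every op that either has memory effects or produces a live
// result, until nothing changes.
std::set<int> computeLive(const std::vector<ToyOp> &ops,
                          const std::vector<int> &returned) {
  std::set<int> live(returned.begin(), returned.end());
  bool changed = true;
  while (changed) {
    changed = false;
    for (const ToyOp &op : ops) {
      bool opIsLive = op.hasMemoryEffects;
      for (int result : op.results)
        if (live.count(result))
          opIsLive = true;
      if (!opIsLive)
        continue;
      for (int operand : op.operands)
        if (live.insert(operand).second)
          changed = true;
    }
  }
  return live;
}
```

In this model a value feeding a store-like op becomes live, while a chain of values that reaches neither a memory effect nor a returned value stays dead and is a candidate for removal.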

Collaborator

Do you want to do this as part of this PR or later?

Contributor Author

@codemzs codemzs Nov 23, 2024

@parsifal-47 It appears that the pass does not treat values passed as block arguments in branch operations as "live" uses. I realized this after I dropped the condition and got the error below:

error: unexpected error: null operand found
cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)

Here, %non_live is passed as a block argument to both ^bb1 and ^bb2. Because the RemoveDeadValues pass does not recognize values used as block arguments in branch operations as "live", it incorrectly removes %non_live. This leads to a null operand error, because cf.cond_br still references %non_live, which no longer exists.

What is confusing is that the pass has a function called cleanRegionBranchOp that is supposed to handle this exact scenario. I think we may need to restructure the pass before we can drop this condition, so I am inclined not to do it as part of this change. Is this something you could handle in a separate change? What are your thoughts?

Collaborator

@joker-eph joker-eph Nov 23, 2024

I probably misunderstood your earlier comment, can you clarify what you were trying to do here:

If I change the test to have memory effects (i.e memref.store) in the blocks then the test passes, i.e:

Which test is this and what is the pass/fail criteria?

Contributor Author

@codemzs codemzs Nov 23, 2024

In the RemoveDeadValues pass, there's an issue when a value is used only as a block argument in branch operations (e.g., cf.cond_br) but has no uses within the successor blocks. The pass incorrectly (?) treats such values as dead and removes their definitions. However, the branch operation still references these values as block arguments, leading to null operands and compiler errors like:

error: unexpected error: null operand found
cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)

Example IR:

func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
  %non_live = arith.constant 0 : i32
  cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32) {tag = "br"}
^bb1(%non_live_0 : i32):
  cf.br ^bb3
^bb2(%non_live_1 : i32):
  cf.br ^bb3
^bb3:
  return
}

Detailed Explanation:

  • What Happens:

    • The value %non_live is defined but only used as a block argument in the branch operation cf.cond_br.

    • It is not used inside the successor blocks ^bb1 and ^bb2.

    • The RemoveDeadValues pass incorrectly considers %non_live as dead and removes its definition.

    • The cf.cond_br operation still references %non_live, resulting in a null operand error.

  • Why It Doesn't Occur When Value Is Used Inside Blocks:

    • If %non_live is used within the successor blocks (e.g., in a memref.store operation with memory effects), the pass correctly identifies it as live.
    • Its definition is preserved, and no error occurs.
  • Cause of the Problem:

    • The liveness analysis in the RemoveDeadValues pass does not consider values used only as block arguments in branch operations to be live if they have no uses within the successor blocks.
    • This leads to the removal of their definitions while they are still referenced by the branch operations.
  • Proposed Solution:

    • Update the RemoveDeadValues pass to treat values used as block arguments in branch operations as live, even if they have no uses within the successor blocks, or alternatively remove the branches altogether when the value has no uses.
    • This ensures that the definitions of such values are not removed, preventing null operand errors.
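
The failure mode and the first proposed fix can be reproduced in a toy model. The sketch below is an assumption-laden illustration, not the pass: `ToyBranch`, `removeDead`, and `hasDanglingOperand` are hypothetical names, values are integers, and the `keepBranchOperands` flag stands in for "treat branch operands as live".

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Hypothetical toy branch: the values it forwards to successor blocks.
struct ToyBranch {
  std::vector<int> operands;
};

// Returns the definitions that survive dead-value removal. When
// keepBranchOperands is false, only values in `live` survive.
std::set<int> removeDead(const std::set<int> &defs, const std::set<int> &live,
                         const std::vector<ToyBranch> &branches,
                         bool keepBranchOperands) {
  std::set<int> kept;
  for (int value : defs) {
    bool usedByBranch = false;
    for (const ToyBranch &branch : branches)
      if (std::find(branch.operands.begin(), branch.operands.end(), value) !=
          branch.operands.end())
        usedByBranch = true;
    if (live.count(value) || (keepBranchOperands && usedByBranch))
      kept.insert(value);
  }
  return kept;
}

// True if some branch still references a value whose definition is gone,
// which corresponds to the "null operand found" error in this model.
bool hasDanglingOperand(const std::set<int> &kept,
                        const std::vector<ToyBranch> &branches) {
  for (const ToyBranch &branch : branches)
    for (int value : branch.operands)
      if (!kept.count(value))
        return true;
  return false;
}
```

With a single definition that is forwarded twice by one branch and is not otherwise live, removal without the flag leaves a dangling operand, and removal with the flag keeps the definition.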

The criterion for a test PASS/FAIL is that the compiler generates the correct IR with no errors. For example, the IR below remains unchanged and the compiler does not produce an error:

 func.func @test_2_RegionBranchOpInterface_type_1.b(%arg0: memref<i32>, %arg1: memref<i32>, %arg2: i1) {
  %c0_i32 = arith.constant 0 : i32
  cf.cond_br %arg2, ^bb1(%c0_i32 : i32), ^bb2(%c0_i32 : i32) {tag = "br"}
^bb1(%0 : i32):
  memref.store %0, %arg0[] : memref<i32>
  cf.br ^bb3
^bb2(%1 : i32):
  memref.store %1, %arg1[] : memref<i32>
  cf.br ^bb3
^bb3:
  return
}

Collaborator

@joker-eph joker-eph Nov 24, 2024

Thanks for the explanation.

What you're describing is a possible solution. We could also try to update the branch and successors to remove the unused value right?
It's also a case where having the ability to replace the value with a poison/undef operation could be useful.

Contributor Author

@codemzs codemzs Nov 24, 2024

If it’s possible to create an empty branch or successor, updating the branch and successor would be a great solution. For instance, in this particular case, it would have resulted in a successor with zero arguments.
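
The alternative discussed here, dropping unused successor arguments together with the operands that feed them, could look roughly like the following in a toy model. This is a sketch only and makes no claim about how the follow-up change is implemented: `ToyBlock`, `ToyEdge`, and `pruneUnusedBlockArgs` are hypothetical names, and each edge is assumed to forward exactly one operand per argument of its target block.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy block: one flag per block argument, true if it is used.
struct ToyBlock {
  std::vector<bool> argIsUsed;
};

// Hypothetical toy branch edge to one successor block.
struct ToyEdge {
  std::size_t target;        // index of the successor block
  std::vector<int> operands; // forwarded values, one per target argument
};

// Removes every unused block argument and the matching operand on each
// edge that targets the block, keeping positions of the rest aligned.
void pruneUnusedBlockArgs(std::vector<ToyBlock> &blocks,
                          std::vector<ToyEdge> &edges) {
  for (std::size_t b = 0; b < blocks.size(); ++b) {
    const std::vector<bool> used = blocks[b].argIsUsed;
    for (ToyEdge &edge : edges) {
      if (edge.target != b)
        continue;
      assert(edge.operands.size() == used.size());
      std::vector<int> kept;
      for (std::size_t i = 0; i < used.size(); ++i)
        if (used[i])
          kept.push_back(edge.operands[i]);
      edge.operands = kept;
    }
    const auto numUsed = std::count(used.begin(), used.end(), true);
    blocks[b].argIsUsed.assign(static_cast<std::size_t>(numUsed), true);
  }
}
```

For the example in this thread, both successors would end up with zero arguments and the branch with zero forwarded operands, after which the definition of the forwarded value has no remaining uses.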

Contributor

I made a follow-up PR, #117501. Please take a look and let me know whether it is sufficient.

@codemzs codemzs requested a review from joker-eph November 23, 2024 03:15
@joker-eph joker-eph merged commit 5f9db08 into llvm:main Nov 23, 2024
8 checks passed
CoTinker pushed a commit that referenced this pull request Dec 5, 2024
Fixing RemoveDeadValues to properly remove arguments from
BranchOpInterface operations.
This is a follow-up for:
#117405
cc: @joker-eph @codemzs

---------

Co-authored-by: Renat Idrisov <[email protected]>
Labels
mlir:core MLIR Core Infrastructure mlir