[mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops #130000

Open. Wants to merge 4 commits into main.

@IanWood1 (Contributor) commented Mar 6, 2025

Adds a check to make sure that the linalg op is safe to erase by ensuring that the linalg.yield is yielding one of the linalg op's block args. This check already exists for linalg ops with pure tensor semantics.

Closes #129414
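
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the logic this check adds. It is a toy model, not MLIR code: `ToyValue`, `ToyLinalgOp`, `canEraseSelfCopy`, and `makeSelfWritingOp` are hypothetical names invented for illustration, and buffers and blocks are represented by plain integer ids.

```cpp
#include <cassert>
#include <vector>

// Toy stand-ins for illustration only; these are NOT MLIR classes.
struct ToyValue {
  int ownerBlock; // id of the block this value is an argument of, or -1
                  // if the value is defined elsewhere (e.g. a constant)
};

struct ToyLinalgOp {
  std::vector<int> inputs; // ids of the input buffers
  std::vector<int> inits;  // ids of the output (init) buffers
  int bodyBlock;           // id of the op's body block
  ToyValue yielded;        // the single value yielded by the body
};

// Sketch of the buffer-semantics branch: erasing is only safe when the op
// reads and writes the same buffer AND the body yields one of its own
// block arguments, rather than some value defined outside the body.
bool canEraseSelfCopy(const ToyLinalgOp &op) {
  if (op.inputs.size() != 1 || op.inits.size() != 1 ||
      op.inputs[0] != op.inits[0]) {
    return false;
  }
  if (op.yielded.ownerBlock == -1 || op.yielded.ownerBlock != op.bodyBlock) {
    return false;
  }
  return true;
}

// Hypothetical helper: builds an op whose input and output are the same
// buffer; the body yields either a block argument or an outside value.
ToyLinalgOp makeSelfWritingOp(bool yieldsBlockArg) {
  const int kBuffer = 0;
  const int kBody = 1;
  ToyLinalgOp op;
  op.inputs.push_back(kBuffer);
  op.inits.push_back(kBuffer);
  op.bodyBlock = kBody;
  op.yielded.ownerBlock = yieldsBlockArg ? kBody : -1;
  return op;
}

// Usage: the self-copy may be erased, the fill-like op may not.
void checkExamples() {
  assert(canEraseSelfCopy(makeSelfWritingOp(true)));
  assert(!canEraseSelfCopy(makeSelfWritingOp(false)));
}
```

The second case models a fill-like op: input and output are the same buffer, but the body yields a constant, so erasing the op would drop the write.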

@llvmbot (Member) commented Mar 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Ian Wood (IanWood1)

Changes

Adds a check to make sure that the linalg op is safe to erase by ensuring that the linalg.yield is yielding one of the generic's block args.

Closes #129414


Full diff: https://github.com/llvm/llvm-project/pull/130000.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+11-7)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+20-1)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..c044c94c5af3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1285,13 +1285,17 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
 
     // In the buffer case, we need to check exact buffer equality.
     if (linalgOp.hasPureBufferSemantics()) {
-      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
-          linalgOp.getDpsInputOperand(0)->get() ==
-              linalgOp.getDpsInitOperand(0)->get()) {
-        rewriter.eraseOp(linalgOp);
-        return success();
-      }
-      return failure();
+      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+          linalgOp.getDpsInputOperand(0)->get() !=
+              linalgOp.getDpsInitOperand(0)->get())
+        return failure();
+
+      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return failure();
+
+      rewriter.eraseOp(linalgOp);
+      return success();
     }
 
     // Mixed semantics is not supported yet.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..08d99c65a291d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -415,7 +415,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
 
 // -----
 
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
 func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -431,6 +431,25 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @no_fold_fill_like
+//       CHECK:   %[[VAL0:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:   linalg.generic 
+//       CHECK:     linalg.yield %[[VAL0]] : f32
+func.func @no_fold_fill_like(%0 : memref<4x16xf32>) {
+  %1 = arith.constant 0.0 : f32
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<4x16xf32>)
+    outs(%0 : memref<4x16xf32>) {
+      ^bb0(%arg4: f32, %arg5: f32):
+        linalg.yield %1 : f32
+    }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_static_pad_fill
 //       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 //       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>

@mdehling (Contributor)

@dcaballe @nicolasvasilache @rengolin Sorry to summon you like this :)

Is there any concern with this PR? The fix looks good to me.

Asking because one of our engineers just spent some time looking into an issue and found this canonicalization bug to be the cause of it.

@banach-space (Contributor)

This looks pretty straightforward and overall LGTM. That said, it's not my area of expertise, so an extra pair of eyes would be helpful.

I've left a few minor suggestions inline. Also - could you update the outdated comment above EraseIdentityLinalgOp?

/// Remove any linalg operation (on tensors) that are just copying

And regarding:

> This check already exists for linalg ops with pure tensor semantics.

Is there a negative test for the tensor case?

@IanWood1 requested a review from banach-space May 29, 2025 17:11
Signed-off-by: Ian Wood <[email protected]>
@IanWood1 (Contributor, Author)

> Is there a negative test for the tensor case?

I couldn't find one, so I added it.

return success();
}
return failure();
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||

Please add { } around multi-line statements here and below.

Successfully merging this pull request may close these issues.

[mlir] Inconsistent results for linalg.generic
5 participants