[mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops #130000
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Ian Wood (IanWood1)

Changes

Adds a check to make sure that the linalg op is safe to erase by ensuring that the linalg.yield is yielding one of the linalg op's block args. This check already exists for linalg ops with pure tensor semantics.

Closes #129414

Full diff: https://github.com/llvm/llvm-project/pull/130000.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..c044c94c5af3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1285,13 +1285,17 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
// In the buffer case, we need to check exact buffer equality.
if (linalgOp.hasPureBufferSemantics()) {
- if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
- linalgOp.getDpsInputOperand(0)->get() ==
- linalgOp.getDpsInitOperand(0)->get()) {
- rewriter.eraseOp(linalgOp);
- return success();
- }
- return failure();
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+ linalgOp.getDpsInputOperand(0)->get() !=
+ linalgOp.getDpsInitOperand(0)->get())
+ return failure();
+
+ auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+ if (!yieldArg || yieldArg.getOwner() != &body)
+ return failure();
+
+ rewriter.eraseOp(linalgOp);
+ return success();
}
// Mixed semantics is not supported yet.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..08d99c65a291d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -415,7 +415,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
// -----
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// CHECK-NEXT: return
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -431,6 +431,25 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// -----
+// CHECK-LABEL: func @no_fold_fill_like
+// CHECK: %[[VAL0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: linalg.generic
+// CHECK: linalg.yield %[[VAL0]] : f32
+func.func @no_fold_fill_like(%0 : memref<4x16xf32>) {
+ %1 = arith.constant 0.0 : f32
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : memref<4x16xf32>)
+ outs(%0 : memref<4x16xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32):
+ linalg.yield %1 : f32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @fold_static_pad_fill
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>
Signed-off-by: Ian Wood <[email protected]>
@dcaballe @nicolasvasilache @rengolin Sorry to summon you like this :) Is there any concern with this PR? The fix looks good to me. Asking because one of our engineers just spent some time looking into an issue and found this canonicalization bug to be the cause of it.
This looks pretty straightforward and overall LGTM. That said, it's not my area of expertise, so an extra pair of eyes would be helpful. I've left a few minor suggestions inline. Also - could you update the outdated comment above
And regarding:
Is there a negative test for the tensor case?
I couldn't find one, so I added it.
return success();
}
return failure();
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
Please add { } around multi-line statements here and below.
Adds a check to make sure that the linalg op is safe to erase by ensuring that the linalg.yield is yielding one of the linalg op's block args. This check already exists for linalg ops with pure tensor semantics.

Closes #129414
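
To illustrate why the extra check matters, here is a minimal standalone sketch of the buffer-case logic. It is a toy model, not the MLIR API: `Block`, `Value`, `ToyLinalgOp`, and `canEraseBufferOp` are hypothetical names invented for this example, and the model covers only the two conditions touched by this patch (same input/init buffer, and the yielded value being a block argument of the op's own body). The real pattern checks more than this before it reaches that point.

```cpp
#include <cassert>

// Toy stand-ins for MLIR concepts. None of these types are the real MLIR API.
struct Block {};

struct Value {
  // Non-null only when the value is a block argument; names the owning block.
  const Block *ownerBlock;
};

struct ToyLinalgOp {
  int numInputs;
  int numInits;
  const void *inputBuffer; // buffer passed via ins(...)
  const void *initBuffer;  // buffer passed via outs(...)
  const Block *body;       // the op's region block
  Value yielded;           // single operand of the terminator
};

// Mirrors the shape of the patched buffer-case check (assumption: simplified).
bool canEraseBufferOp(const ToyLinalgOp &op) {
  // Old condition: exactly one input and one init, and they are the same buffer.
  if (op.numInputs != 1 || op.numInits != 1 ||
      op.inputBuffer != op.initBuffer) {
    return false;
  }
  // New condition: the yielded value must be a block argument of this body.
  // A value defined outside the body (e.g. a constant) overwrites the buffer,
  // so erasing the op would change behavior.
  if (op.yielded.ownerBlock == nullptr || op.yielded.ownerBlock != op.body) {
    return false;
  }
  return true;
}

// Usage: models @fold_self_copy and @no_fold_fill_like from the diff.
void checkExamples() {
  Block body;
  int buffer = 0;
  ToyLinalgOp selfCopy{1, 1, &buffer, &buffer, &body, Value{&body}};
  ToyLinalgOp fillLike{1, 1, &buffer, &buffer, &body, Value{nullptr}};
  assert(canEraseBufferOp(selfCopy));
  assert(!canEraseBufferOp(fillLike));
}
```

In this model the fill-like op passes the old condition (its input and init are the same buffer) and is rejected only by the new one, which is the case the `@no_fold_fill_like` test is meant to pin down.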