-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][transform] Fix failure in flattening already flattened linalg ops #86037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: None (srcarroll) ChangesFull diff: https://github.com/llvm/llvm-project/pull/86037.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- if (target.getNumLoops() <= 1)
+ if (!isElementwise(target)) {
+ failed(rewriter.notifyMatchFailure(
+ target, "only elementwise flattening is supported"));
+ return emitDefaultSilenceableFailure(target);
+ }
+ // If rank <= 1, do nothing
+ if (target.getNumLoops() <= 1) {
+ results.push_back(target);
return DiagnosedSilenceableFailure::success();
+ }
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
- (isElementwise(target))
- ? collapseOpIterationDims(target, reassociation, rewriter)
- : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
- target, "only elementwise flattening is supported"));
+ collapseOpIterationDims(target, reassociation, rewriter);
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT: linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
@llvm/pr-subscribers-mlir-linalg Author: None (srcarroll) ChangesFull diff: https://github.com/llvm/llvm-project/pull/86037.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- if (target.getNumLoops() <= 1)
+ if (!isElementwise(target)) {
+ failed(rewriter.notifyMatchFailure(
+ target, "only elementwise flattening is supported"));
+ return emitDefaultSilenceableFailure(target);
+ }
+ // If rank <= 1, do nothing
+ if (target.getNumLoops() <= 1) {
+ results.push_back(target);
return DiagnosedSilenceableFailure::success();
+ }
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
- (isElementwise(target))
- ? collapseOpIterationDims(target, reassociation, rewriter)
- : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
- target, "only elementwise flattening is supported"));
+ collapseOpIterationDims(target, reassociation, rewriter);
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT: linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok.. I am assuming the flattening transformation does the right thing (i.e. does not nothing) if things are already flattened. If not maybe better to return failure there too.
yes it does nothing in the case the op is already a rank <= 1 op. I don't think returning failure is appropriate in this case though as i consider doing nothing on already flattened is a success. Maybe i misunderstood |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a test for the user-visible error message after making it user-visible.
I would have requested changes on this if I had the opportunity. Please consider tagging me on transform dialect-related reviews.
failed(rewriter.notifyMatchFailure( | ||
target, "only elementwise flattening is supported")); | ||
return emitDefaultSilenceableFailure(target); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of sending the error message to the non-existent rewrite driver, this should rather emit it as a silenceable failure! The default failure messages produced by emitDefaultSilenceableFailure
are extremely unhelpful and must only be used as last resort.
I also haven't seen the usage of failed
as a way to suppress the compiler warning about LogicalResult
being unused. The common idiom is to use the C-style cast to void.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I'm new to the transform framework and still learning how things work. I was just following another implementation in this same file for the message. I'll submit another PR to fix.
About tagging you. I made an RFC for this per your suggestion with a mention to the original PR and Mahesh was the only one that took any action so I just went with him again for this. I'll be sure to include you in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No big deal, the point of review is also for contributors to learn things.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
appreciate it. thanks for teaching me. :)
…ops (llvm#86037) The previous implementation was doing an early successful return on `rank <= 1` without adding the original op to transform results. This resulted in errors about number of returns. This patch fixes this by adding the original op to results. Additionally, we first check if op is elementwise and return a slienceable failure early if not.
The previous implementation was doing an early successful return on
rank <= 1
without adding the original op to transform results. This resulted in errors about number of returns. This patch fixes this by adding the original op to results. Additionally, we first check if op is elementwise and return a slienceable failure early if not.