Skip to content

[mlir][transform] Fix failure in flattening already flattened linalg ops #86037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2024

Conversation

srcarroll
Copy link
Contributor

@srcarroll srcarroll commented Mar 20, 2024

The previous implementation was doing an early successful return on rank <= 1 without adding the original op to transform results. This resulted in errors about number of returns. This patch fixes this by adding the original op to results. Additionally, we first check if op is elementwise and return a slienceable failure early if not.

@llvmbot
Copy link
Member

llvmbot commented Mar 20, 2024

@llvm/pr-subscribers-mlir

Author: None (srcarroll)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/86037.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+10-5)
  • (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+21)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  if (target.getNumLoops() <= 1)
+  if (!isElementwise(target)) {
+    failed(rewriter.notifyMatchFailure(
+        target, "only elementwise flattening is supported"));
+    return emitDefaultSilenceableFailure(target);
+  }
+  // If rank <= 1, do nothing
+  if (target.getNumLoops() <= 1) {
+    results.push_back(target);
     return DiagnosedSilenceableFailure::success();
+  }
   ReassociationIndices reassociation(target.getNumLoops());
   std::iota(reassociation.begin(), reassociation.end(), 0);
   auto maybeFlattened =
-      (isElementwise(target))
-          ? collapseOpIterationDims(target, reassociation, rewriter)
-          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
-                target, "only elementwise flattening is supported"));
+      collapseOpIterationDims(target, reassociation, rewriter);
   if (failed(maybeFlattened))
     return emitDefaultSilenceableFailure(target);
   results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+    linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @generic
 // CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>

@llvmbot
Copy link
Member

llvmbot commented Mar 20, 2024

@llvm/pr-subscribers-mlir-linalg

Author: None (srcarroll)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/86037.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+10-5)
  • (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+21)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  if (target.getNumLoops() <= 1)
+  if (!isElementwise(target)) {
+    failed(rewriter.notifyMatchFailure(
+        target, "only elementwise flattening is supported"));
+    return emitDefaultSilenceableFailure(target);
+  }
+  // If rank <= 1, do nothing
+  if (target.getNumLoops() <= 1) {
+    results.push_back(target);
     return DiagnosedSilenceableFailure::success();
+  }
   ReassociationIndices reassociation(target.getNumLoops());
   std::iota(reassociation.begin(), reassociation.end(), 0);
   auto maybeFlattened =
-      (isElementwise(target))
-          ? collapseOpIterationDims(target, reassociation, rewriter)
-          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
-                target, "only elementwise flattening is supported"));
+      collapseOpIterationDims(target, reassociation, rewriter);
   if (failed(maybeFlattened))
     return emitDefaultSilenceableFailure(target);
   results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+    linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @generic
 // CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok.. I am assuming the flattening transformation does the right thing (i.e. does not nothing) if things are already flattened. If not maybe better to return failure there too.

@srcarroll
Copy link
Contributor Author

srcarroll commented Mar 20, 2024

yes it does nothing in the case the op is already a rank <= 1 op. I don't think returning failure is appropriate in this case though as i consider doing nothing on already flattened is a success. Maybe i misunderstood

@srcarroll srcarroll merged commit df9ed9c into llvm:main Mar 21, 2024
Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test for the user-visible error message after making it user-visible.

I would have requested changes on this if I had the opportunity. Please consider tagging me on transform dialect-related reviews.

Comment on lines +3273 to +3275
failed(rewriter.notifyMatchFailure(
target, "only elementwise flattening is supported"));
return emitDefaultSilenceableFailure(target);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of sending the error message to the non-existent rewrite driver, this should rather emit it as a silenceable failure! The default failure messages produced by emitDefaultSilenceableFailure are extremely unhelpful and must only be used as last resort.

I also haven't seen the usage of failed as a way to suppress the compiler warning about LogicalResult being unused. The common idiom is to use the C-style cast to void.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm new to the transform framework and still learning how things work. I was just following another implementation in this same file for the message. I'll submit another PR to fix.

About tagging you. I made an RFC for this per your suggestion with a mention to the original PR and Mahesh was the only one that took any action so I just went with him again for this. I'll be sure to include you in the future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No big deal, the point of review is also for contributors to learn things.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

appreciate it. thanks for teaching me. :)

chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…ops (llvm#86037)

The previous implementation was doing an early successful return on
`rank <= 1` without adding the original op to transform results. This
resulted in errors about number of returns. This patch fixes this by
adding the original op to results. Additionally, we first check if op is
elementwise and return a slienceable failure early if not.
@srcarroll srcarroll deleted the fix-elementwise-flattening branch June 5, 2024 02:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants