[mlir][Linalg] Add folders for linalg.transpose #81709


Merged: 3 commits into llvm:main from fold-linalg-transpose on Feb 21, 2024

Conversation

dcaballe
Contributor

This PR adds folders for `linalg.transpose` ops with a single dimension or an identity permutation. The folding removes the `linalg.transpose` and just propagates the input tensor. Given that this is a DPS op, I'm now wondering whether this folding is incorrect and we should instead replace the op with a `linalg.copy` so that the init tensor is still used. Feedback would be appreciated. I think that propagating the input tensor when the DPS op is folded away should be OK, given that all uses of the op's result are replaced with the input tensor, but I might be missing something.
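
For concreteness, a minimal sketch of what the fold does to the IR (the shape and the generic `"some.use"` op below are illustrative, not taken from the patch):

```mlir
// Before folding: the transpose threads its result through the DPS init.
%t = linalg.transpose
    ins(%input : tensor<16xf32>)
    outs(%init : tensor<16xf32>)
    permutation = [0]
"some.use"(%t) : (tensor<16xf32>) -> ()

// After folding: all uses of %t are replaced with %input directly,
// so %init no longer feeds anything through this op.
"some.use"(%input) : (tensor<16xf32>) -> ()
```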
@llvmbot
Member

llvmbot commented Feb 14, 2024

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Full diff: https://github.com/llvm/llvm-project/pull/81709.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+16)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+34)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..de9414598b0282 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -245,7 +245,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
           }
     ```
 
-    Shortened print form is available. Applies to simple maps with one 
+    Shortened print form is available. Applies to simple maps with one
     non-yield operation inside the body.
 
     The example above will be printed as:
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
                              ::mlir::OperationState & odsState);
   }];
 
+  let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e86b9762d8581f..2f6ab7e32e5872 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1785,6 +1785,22 @@ void TransposeOp::getEffects(
                         getDpsInits());
 }
 
+LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
+                                SmallVectorImpl<OpFoldResult> &result) {
+  // Single dimension transpose.
+  if (getPermutation().size() == 0) {
+    result.push_back(getInput());
+    return success();
+  }
+  // Identity permutation.
+  if (isIdentityPermutation(getPermutation())) {
+    result.push_back(getInput());
+    return success();
+  }
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 052dc367ca6779..5bc4cb82f8cdbc 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,3 +1017,37 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
   %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %copy : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @transpose_1d(%input: tensor<16xf32>,
+                        %init: tensor<16xf32>) -> tensor<16xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16xf32>)
+      outs(%init:tensor<16xf32>)
+      permutation = [0]
+  func.return %transpose : tensor<16xf32>
+}
+
+// CHECK-LABEL: func @transpose_1d(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16xf32>
+
+// -----
+
+func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
+                                   %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x32x64xf32>)
+      permutation = [0, 1, 2]
+  func.return %transpose : tensor<16x32x64xf32>
+}
+
+// CHECK-LABEL: func @transpose_identity_perm(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16x32x64xf32>
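
For contrast, the alternative raised in the description — rewriting rather than folding — would keep the init tensor in use. A hypothetical sketch of what the identity-permutation case would become under that scheme (this is not what the patch does):

```mlir
// Hypothetical alternative: rewrite the identity transpose into a copy
// so that %init, the DPS init tensor, remains a consumed operand.
%r = linalg.copy
    ins(%input : tensor<16x32x64xf32>)
    outs(%init : tensor<16x32x64xf32>) -> tensor<16x32x64xf32>
```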

@llvmbot
Member

llvmbot commented Feb 14, 2024

@llvm/pr-subscribers-mlir-linalg
@dcaballe
Contributor Author

Kind ping :)

@dcaballe dcaballe merged commit b9a071d into llvm:main Feb 21, 2024
@dcaballe dcaballe deleted the fold-linalg-transpose branch February 21, 2024 01:40
@dcaballe dcaballe restored the fold-linalg-transpose branch April 24, 2024 13:37