
[mlir][tensor] Simplify pad-like tensor pack and unpack #92388


Merged
5 commits merged into llvm:main on May 24, 2024

Conversation

adam-smnk
Contributor

Extend existing tensor patterns to simplify pad-like tensor pack/unpack into expand/collapse shape operations.
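
For illustration, this mirrors the new test cases below: a pack whose inner tiles cover the whole source dimensions only introduces unit outer dimensions, so it can be expressed as a plain reshape (a sketch; the function and value names are only illustrative):

```mlir
// Pad-like pack: inner_tiles equal the full source dims, so the result only
// gains leading unit dims and the op behaves like a pad to [1, 1, 32, 64].
func.func @pad_like_pack_example(%src: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
  %dest = tensor.empty() : tensor<1x1x32x64xf32>
  %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [32, 64]
      into %dest : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
  return %packed : tensor<1x1x32x64xf32>
}
// After simplification, the pack is expected to become a single reshape:
//   %expanded = tensor.expand_shape %src [[0, 1, 2], [3]] output_shape [1, 1, 32, 64]
//       : tensor<32x64xf32> into tensor<1x1x32x64xf32>
```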

Extend existing tensor patterns to simplify pad-like tensor pack/unpack into expand/collapse shape operations.
@llvmbot
Member

llvmbot commented May 16, 2024

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Extend existing tensor patterns to simplify pad-like tensor pack/unpack into expand/collapse shape operations.


Full diff: https://github.com/llvm/llvm-project/pull/92388.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+4-2)
  • (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+48)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..5d6e3ec9756af 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -91,7 +91,8 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
     RankedTensorType sourceType = packOp.getSourceType();
     if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
         failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
-                          packOp.getStaticTiles()))) {
+                          packOp.getStaticTiles())) &&
+        !packOp.isLikePad()) {
       return failure();
     }
 
@@ -152,7 +153,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
     RankedTensorType destType = unpackOp.getDestType();
     if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
         failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
-                          unpackOp.getStaticTiles()))) {
+                          unpackOp.getStaticTiles())) &&
+        !unpackOp.isLikeUnPad()) {
       return failure();
     }
 
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 5a2eade0ecccf..0ad93863dc501 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -266,3 +266,51 @@ func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1x
     : tensor<16x1x1x2xf32> -> tensor<32x1xf32>
   return %unpack : tensor<32x1xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @pad_like_pack(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x64xf32>)
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<1x1x32x64xf32>
+func.func @pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
+  %empty = tensor.empty() : tensor<1x1x32x64xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
+  return %0 : tensor<1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pad_like_pack_with_outer_dims_perm(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x64xf32>)
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<1x1x32x64xf32>
+func.func @pad_like_pack_with_outer_dims_perm(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
+  %empty = tensor.empty() : tensor<1x1x32x64xf32>
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
+  return %0 : tensor<1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpad_like_unpack(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
+// CHECK:         return %[[COLLAPSED]]
+func.func @unpad_like_unpack(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
+  %empty = tensor.empty() : tensor<32x64xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
+  return %0 : tensor<32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpad_like_unpack_with_outer_dims_perm(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
+// CHECK:         return %[[COLLAPSED]]
+func.func @unpad_like_unpack_with_outer_dims_perm(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
+  %empty = tensor.empty() : tensor<32x64xf32>
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
+  return %0 : tensor<32x64xf32>
+}

@hanhanW requested a review from @chelini on May 16, 2024, 16:50
@adam-smnk
Contributor Author

@chelini Any comments?

@chelini (Contributor) left a comment

Can we please add a test where the introduced check fails?
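
For instance, something along these lines could work, where the inner tiles actually tile the source dims so the op is not pad-like and must stay as tensor.pack (a rough sketch; the shapes and FileCheck lines are only placeholders):

```mlir
// CHECK-LABEL: func.func @non_pad_like_pack(
// CHECK-NOT:     tensor.expand_shape
// CHECK:         tensor.pack
// CHECK-NOT:     tensor.expand_shape
func.func @non_pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<16x4x2x16xf32> {
  %empty = tensor.empty() : tensor<16x4x2x16xf32>
  // Tiles [2, 16] are smaller than the source dims, so this is real tiling,
  // not padding, and the simplification should not fire.
  %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [2, 16] into %empty : tensor<32x64xf32> -> tensor<16x4x2x16xf32>
  return %0 : tensor<16x4x2x16xf32>
}
```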

@chelini (Contributor) left a comment

thanks, looks good to me.

@adam-smnk merged commit a79a0c5 into llvm:main on May 24, 2024
7 checks passed