-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][tensor] Simplify pad-like tensor pack and unpack #92388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Extend existing tensor patterns to simplify pad-like tensor pack/unpack into expand/collapse shape operations.
@llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesExtend existing tensor patterns to simplify pad-like tensor pack/unpack into expand/collapse shape operations. Full diff: https://github.com/llvm/llvm-project/pull/92388.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..5d6e3ec9756af 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -91,7 +91,8 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
RankedTensorType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
- packOp.getStaticTiles()))) {
+ packOp.getStaticTiles())) &&
+ !packOp.isLikePad()) {
return failure();
}
@@ -152,7 +153,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
RankedTensorType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
- unpackOp.getStaticTiles()))) {
+ unpackOp.getStaticTiles())) &&
+ !unpackOp.isLikeUnPad()) {
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 5a2eade0ecccf..0ad93863dc501 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -266,3 +266,51 @@ func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1x
: tensor<16x1x1x2xf32> -> tensor<32x1xf32>
return %unpack : tensor<32x1xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @pad_like_pack(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
+func.func @pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<1x1x32x64xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
+ return %0 : tensor<1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pad_like_pack_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
+func.func @pad_like_pack_with_outer_dims_perm(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<1x1x32x64xf32>
+ %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
+ return %0 : tensor<1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpad_like_unpack(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
+// CHECK: return %[[COLLAPSED]]
+func.func @unpad_like_unpack(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
+ %empty = tensor.empty() : tensor<32x64xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
+ return %0 : tensor<32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpad_like_unpack_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
+// CHECK: return %[[COLLAPSED]]
+func.func @unpad_like_unpack_with_outer_dims_perm(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
+ %empty = tensor.empty() : tensor<32x64xf32>
+ %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
+ return %0 : tensor<32x64xf32>
+}
|
@llvm/pr-subscribers-mlir-tensor Author: Adam Siemieniuk (adam-smnk) ChangesExtend existing tensor patterns to simplify pad-like tensor pack/unpack into expand/collapse shape operations. Full diff: https://github.com/llvm/llvm-project/pull/92388.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..5d6e3ec9756af 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -91,7 +91,8 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
RankedTensorType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
- packOp.getStaticTiles()))) {
+ packOp.getStaticTiles())) &&
+ !packOp.isLikePad()) {
return failure();
}
@@ -152,7 +153,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
RankedTensorType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
- unpackOp.getStaticTiles()))) {
+ unpackOp.getStaticTiles())) &&
+ !unpackOp.isLikeUnPad()) {
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 5a2eade0ecccf..0ad93863dc501 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -266,3 +266,51 @@ func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1x
: tensor<16x1x1x2xf32> -> tensor<32x1xf32>
return %unpack : tensor<32x1xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @pad_like_pack(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
+func.func @pad_like_pack(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<1x1x32x64xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
+ return %0 : tensor<1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pad_like_pack_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
+func.func @pad_like_pack_with_outer_dims_perm(%arg0: tensor<32x64xf32>) -> tensor<1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<1x1x32x64xf32>
+ %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<32x64xf32> -> tensor<1x1x32x64xf32>
+ return %0 : tensor<1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpad_like_unpack(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
+// CHECK: return %[[COLLAPSED]]
+func.func @unpad_like_unpack(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
+ %empty = tensor.empty() : tensor<32x64xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
+ return %0 : tensor<32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpad_like_unpack_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
+// CHECK: return %[[COLLAPSED]]
+func.func @unpad_like_unpack_with_outer_dims_perm(%arg0: tensor<1x1x32x64xf32>) -> tensor<32x64xf32> {
+ %empty = tensor.empty() : tensor<32x64xf32>
+ %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %empty : tensor<1x1x32x64xf32> -> tensor<32x64xf32>
+ return %0 : tensor<32x64xf32>
+}
|
@chelini Any comments? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we please add a test where the introduced check fails?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, looks good to me.
Extend existing tensor patterns to simplify pad-like tensor pack/unpack into expand/collapse shape operations.