[MLIR] Add pattern to fold insert_slice of extract_slice #86328

Merged
merged 5 commits on Mar 28, 2024

Conversation

@pzread (Member) commented Mar 22, 2024

Fold the tensor.insert_slice of tensor.extract_slice into a single tensor.extract_slice when the insert_slice simply expands unit dims that were dropped by the extract_slice.

For example:

%extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] : tensor<2x8xf32> to tensor<8xf32>
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0] [1, 8] [1, 1] : tensor<8xf32> into tensor<1x8xf32>

can be folded into:

%extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] : tensor<2x8xf32> to tensor<1x8xf32>
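
A minimal usage sketch (not part of this PR): downstream code can pick up this fold through tensor::populateFoldTensorSubsetOpPatterns, where the diff below registers the new pattern. The wrapper foldTensorSubsets and the exact include paths are illustrative; the populate function and the greedy rewrite driver are existing MLIR APIs.

#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Apply the tensor subset folding patterns (including the new
// insert_slice-of-extract_slice fold) to every op nested under `root`.
static LogicalResult foldTensorSubsets(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  tensor::populateFoldTensorSubsetOpPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Note that the review discussion below suggests moving the pattern next to DropRedundantInsertSliceRankExpansion, which would change the populate function it is exposed through.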

@pzread pzread force-pushed the insert-extract-fold branch from 77afad7 to 76ebd1d on March 22, 2024 19:08
@pzread pzread marked this pull request as ready for review March 22, 2024 19:08
@llvmbot (Member) commented Mar 22, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Jerry Wu (pzread)

Changes

Fold the tensor.insert_slice of tensor.extract_slice into a single tensor.extract_slice when the insert_slice simply expands unit dims that were dropped by the extract_slice.

For example:

%extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] : tensor<2x8xf32> to tensor<8xf32>
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0] [1, 8] [1, 1] : tensor<8xf32> into tensor<1x8xf32>

can be folded into:

%extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] : tensor<2x8xf32> to tensor<1x8xf32>

Full diff: https://github.com/llvm/llvm-project/pull/86328.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+88-2)
  • (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+4-2)
  • (modified) mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir (+65)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb7314..d04d1b5eaf5c5b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
@@ -65,6 +66,17 @@ class InsertSliceOfTransferWriteOpFolder final
   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Merge insert_slice operation with extract_slice operation.
+class InsertSliceOfExtractSliceFolder final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 template <typename XferOp, typename ExtractOrInsertOp>
@@ -147,6 +159,80 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   return success();
 }
 
+/// Merge insert_slice operation with extract_slice operation.
+///
+/// This can be done when the insert_slice op purely expands ranks (adds unit
+/// dims) and the extract_slice drops the corresponding unit dims. For example:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<8xf32>
+/// %inserted_slice = tensor.insert_slice %extracted_slice
+///     into %dest[0, 0] [1, 8] [1, 1]
+///     : tensor<8xf32> into tensor<1x8xf32>
+///
+/// can be folded into:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<1x8xf32>
+LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
+    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
+  auto extractSliceOp =
+      insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+  if (!extractSliceOp)
+    return failure();
+
+  // Can't fold if the extract_slice op has other users.
+  if (!extractSliceOp->hasOneUse())
+    return failure();
+
+  // Check if the insert_slice op purely expands ranks (adds unit dims).
+  if (!isCastLikeInsertSliceOp(insertSliceOp))
+    return failure();
+
+  llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
+  llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+  // Can't fold if the insert_slice op expands more dims.
+  if (extractDroppedDims.size() < insertExpandedDims.size())
+    return failure();
+
+  // Try to match the dropped unit dims to the expanded unit dims. This is done
+  // by scanning the dims of extract_slice and finding the left-most one that
+  // matches the current dim of insert_slice. If a match is found, advance to
+  // the next dim of insert_slice.
+  int64_t insertPos = 0;
+  for (int64_t extractPos = 0; extractPos < extractDroppedDims.size();
+       ++extractPos) {
+    // Matched all expanded dims.
+    if (insertPos == insertExpandedDims.size())
+      break;
+
+    bool isDropped = extractDroppedDims[extractPos];
+    bool isExpanded = insertExpandedDims[insertPos];
+    // Match if both sides drop/keep the dim. Advance and match the next dim of
+    // insert_slice.
+    if (isDropped == isExpanded) {
+      insertPos += 1;
+    } else if (!isDropped && isExpanded) {
+      // Not enough dropped unit dims to match the expanded unit dims.
+      return failure();
+    }
+    // If the dim is dropped by extract_slice but not expanded by insert_slice,
+    // look at the next dim of extract_slice to see if it can match the current
+    // dim of insert_slice.
+  }
+  // Can't match some expanded dims.
+  if (insertPos != insertExpandedDims.size())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+      insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
+      extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+      extractSliceOp.getMixedStrides());
+  rewriter.eraseOp(extractSliceOp);
+
+  return success();
+}
+
 template <typename OpTy>
 struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -224,8 +310,8 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
   populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
   patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
-               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
-      patterns.getContext());
+               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>,
+               InsertSliceOfExtractSliceFolder>(patterns.getContext());
 }
 
 void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 186f85d2ce20a6..4bc966f2079d8a 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -142,11 +142,13 @@ mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
+  RankedTensorType resultType = op.getDestType();
   // Source dims and destination dims (apart from dropped dims) must have the
   // same size.
-  for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
-       ++resultDim) {
+  for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
     if (droppedDims.test(resultDim)) {
+      if (resultType.getDimSize(resultDim) != 1)
+        return false;
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index f2e529b4cac950..bb1df99d4c97ee 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -390,3 +390,68 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
   }
   return %0: tensor<12x34xf32>
 }
+
+// -----
+
+func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
+  return %inserted_slice : tensor<8x1x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
+// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
+// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
+
+// -----
+
+func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
+  return %inserted_slice : tensor<1x4x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
+// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
+// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
+
+// -----
+
+func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
+  return %inserted_slice : tensor<1x1x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
+
+// -----
+
+func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
+  return %inserted_slice : tensor<1x4x4xf32>
+}
+// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
+
+// -----
+
+func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
+  return %inserted_slice : tensor<2x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>
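
To make the dim-matching in InsertSliceOfExtractSliceFolder above easier to follow, here is a standalone sketch of the same left-to-right scan. canFold is illustrative only and uses std::vector<bool> in place of llvm::SmallBitVector; it is not the actual MLIR code.

#include <cstddef>
#include <iostream>
#include <vector>

// extractDroppedDims[i] is true if extract_slice drops source dim i.
// insertExpandedDims[i] is true if insert_slice re-adds dest dim i as a unit dim.
bool canFold(const std::vector<bool> &extractDroppedDims,
             const std::vector<bool> &insertExpandedDims) {
  // Can't fold if insert_slice expands more dims than extract_slice dropped.
  if (extractDroppedDims.size() < insertExpandedDims.size())
    return false;
  size_t insertPos = 0;
  for (size_t extractPos = 0; extractPos < extractDroppedDims.size();
       ++extractPos) {
    if (insertPos == insertExpandedDims.size())
      break; // All expanded dims matched.
    bool isDropped = extractDroppedDims[extractPos];
    bool isExpanded = insertExpandedDims[insertPos];
    if (isDropped == isExpanded) {
      ++insertPos; // Both sides drop/keep this dim: it lines up, advance.
    } else if (!isDropped && isExpanded) {
      // insert_slice expands a unit dim where extract_slice kept a real dim.
      return false;
    }
    // Otherwise extract_slice dropped a dim that insert_slice does not
    // re-expand here; skip it and keep scanning extract_slice.
  }
  return insertPos == insertExpandedDims.size();
}

int main() {
  // First new test above: extract drops dims {0, 2} of tensor<?x8x2x8xf32>,
  // insert re-adds a unit dim at position 1 of tensor<8x1x8xf32> -> foldable.
  std::cout << canFold({true, false, true, false}, {false, true, false})
            << "\n"; // prints 1
  // Hypothetical non-foldable case: extract drops only dim 0, but insert
  // expands a unit dim at position 1, so the unit dim would move relative to
  // the kept dims, which a single extract_slice cannot express.
  std::cout << canFold({true, false, false, false}, {false, true, false, false})
            << "\n"; // prints 0
}

The scan is greedy: every unit dim expanded by insert_slice must line up, in order, with a unit dim dropped by extract_slice, so the non-unit dims keep their relative positions.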


✅ With the latest revision this PR passed the Python code formatter.


✅ With the latest revision this PR passed the C/C++ code formatter.

@pzread pzread force-pushed the insert-extract-fold branch from 76ebd1d to f4d6bf2 on March 22, 2024 19:10
@pzread pzread changed the title from "Add pattern to fold insert_slice of extract_slice" to "[MLIR] Add pattern to fold insert_slice of extract_slice" on Mar 22, 2024
@pzread pzread force-pushed the insert-extract-fold branch from f4d6bf2 to dceb7b7 on March 22, 2024 21:09
@matthias-springer (Member) left a comment


I think this pattern would fit best next to DropRedundantInsertSliceRankExpansion (maybe rename populateDropRedundantInsertSliceRankExpansionPatterns).

@pzread pzread force-pushed the insert-extract-fold branch 3 times, most recently from 370193c to 1621d6d on March 26, 2024 17:50
@pzread (Member, Author) commented Mar 26, 2024

I think this pattern would fit best next to DropRedundantInsertSliceRankExpansion (maybe rename populateDropRedundantInsertSliceRankExpansionPatterns).

Done.

    if (droppedDims.test(resultDim)) {
      // InsertSlice may expand unit dimensions that result from inserting a
      // size-1 slice into a non-size-1 result dimension.
      if (resultType.getDimSize(resultDim) != 1)
Member


Why is this needed? If a dimension is "dropped", it must be a unit dimension.

@pzread (Member, Author) commented Mar 27, 2024

This checks whether a dropped unit dim is inserted into a non-unit dim. For example:

tensor.insert_slice %x into %y[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>

In this case, the first dim of the source 1x8x8xf32 is seen as a dropped unit dim, but it is inserted into a destination dim of size 2 (2x8x8xf32). I don't think we should consider this a cast-like op, since the source (1)x8x8xf32 != the destination 2x8x8xf32.
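
To illustrate, here is a standalone sketch of the updated check on static shapes. isCastLikeInsert is an illustrative stand-in, not MLIR API; the real isCastLikeInsertSliceOp also handles dynamic sizes through ValueBoundsConstraintSet.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// droppedDims[d] is true if dest dim d has no counterpart in the source.
bool isCastLikeInsert(const std::vector<int64_t> &srcShape,
                      const std::vector<int64_t> &destShape,
                      const std::vector<bool> &droppedDims) {
  size_t srcDim = 0;
  for (size_t destDim = 0; destDim < destShape.size(); ++destDim) {
    if (droppedDims[destDim]) {
      // The check added in this PR: an expanded dim must land in a unit dim of
      // the destination, otherwise the size-1 slice covers only part of it.
      if (destShape[destDim] != 1)
        return false;
      continue;
    }
    // Source dims and destination dims (apart from dropped dims) must match.
    if (srcShape[srcDim++] != destShape[destDim])
      return false;
  }
  return true;
}

int main() {
  // tensor<8x8xf32> into tensor<1x8x8xf32>: pure rank expansion, cast-like.
  std::cout << isCastLikeInsert({8, 8}, {1, 8, 8}, {true, false, false})
            << "\n"; // prints 1
  // The example above: tensor<8x8xf32> into tensor<2x8x8xf32>; without the
  // new dim-size check this would wrongly be treated as cast-like.
  std::cout << isCastLikeInsert({8, 8}, {2, 8, 8}, {true, false, false})
            << "\n"; // prints 0
}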

@pzread pzread force-pushed the insert-extract-fold branch from 1621d6d to 0db53ae on March 27, 2024 20:58
@hanhanW (Contributor) left a comment

LGTM, thanks
