[MLIR] Add pattern to fold insert_slice of extract_slice #86328
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Jerry Wu (pzread)

Changes

Fold the tensor.insert_slice of tensor.extract_slice into tensor.extract_slice when the insert_slice simply expands some unit dims dropped by the extract_slice.

For example:

%extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] : tensor<2x8xf32> to tensor<8xf32>
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0] [1, 8] [1, 1] : tensor<8xf32> into tensor<1x8xf32>

can be folded into:

%extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] : tensor<2x8xf32> to tensor<1x8xf32>

Full diff: https://github.com/llvm/llvm-project/pull/86328.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb7314..d04d1b5eaf5c5b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
@@ -65,6 +66,17 @@ class InsertSliceOfTransferWriteOpFolder final
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override;
};
+
+/// Merge insert_slice operation with extract_slice operation.
+class InsertSliceOfExtractSliceFolder final
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const override;
+};
+
} // namespace
template <typename XferOp, typename ExtractOrInsertOp>
@@ -147,6 +159,80 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
return success();
}
+/// Merge insert_slice operation with extract_slice operation.
+///
+/// This can be done when the insert_slice op purely expands ranks (adds unit
+/// dims) and the extract_slice drops corresponding unit dims. For example:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+/// : tensor<2x8xf32> to tensor<8xf32>
+/// %inserted_slice = tensor.insert_slice %extracted_slice
+/// into %dest[0, 0] [1, 8] [1, 1]
+/// : tensor<8xf32> into tensor<1x8xf32>
+///
+/// can be folded into:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+/// : tensor<2x8xf32> to tensor<1x8xf32>
+LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
+ tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
+ auto extractSliceOp =
+ insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractSliceOp)
+ return failure();
+
+ // Can't fold if the extract_slice op has other users.
+ if (!extractSliceOp->hasOneUse())
+ return failure();
+
+ // Check if the insert_slice op purely expands ranks (add unit dims).
+ if (!isCastLikeInsertSliceOp(insertSliceOp))
+ return failure();
+
+ llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
+ llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+ // Can't fold if the insert_slice op expands more dims.
+ if (extractDroppedDims.size() < insertExpandedDims.size())
+ return failure();
+
+ // Try to match the dropped unit dims to the expanded unit dims. This is done
+ // by scanning the dims of extract_slice and finding the left-most one that
+ // can match the current dim of insert_slice. If a match is found, advance to
+ // the next dim of insert_slice.
+ int64_t insertPos = 0;
+ for (int64_t extractPos = 0; extractPos < extractDroppedDims.size();
+ ++extractPos) {
+ // Matched all expanded dims.
+ if (insertPos == insertExpandedDims.size())
+ break;
+
+ bool isDropped = extractDroppedDims[extractPos];
+ bool isExpanded = insertExpandedDims[insertPos];
+ // Match if both sides drop/keep the dim. Advance and match the next dim of
+ // insert_slice.
+ if (isDropped == isExpanded) {
+ insertPos += 1;
+ } else if (!isDropped && isExpanded) {
+ // Not enough dropped unit dims to match the expanded unit dims.
+ return failure();
+ }
+ // If the dim is dropped by extract_slice and not by insert_slice, look at
+ // the next dim of extract_slice to see if it can match the current dim of
+ // insert_slice.
+ }
+ // Can't match some expanded dims.
+ if (insertPos != insertExpandedDims.size())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+ insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
+ extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+ extractSliceOp.getMixedStrides());
+ rewriter.eraseOp(extractSliceOp);
+
+ return success();
+}
+
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -224,8 +310,8 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
- InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
- patterns.getContext());
+ InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>,
+ InsertSliceOfExtractSliceFolder>(patterns.getContext());
}
void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 186f85d2ce20a6..4bc966f2079d8a 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -142,11 +142,13 @@ mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
+ RankedTensorType resultType = op.getDestType();
// Source dims and destination dims (apart from dropped dims) must have the
// same size.
- for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
- ++resultDim) {
+ for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
if (droppedDims.test(resultDim)) {
+ if (resultType.getDimSize(resultDim) != 1)
+ return false;
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index f2e529b4cac950..bb1df99d4c97ee 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -390,3 +390,68 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
}
return %0: tensor<12x34xf32>
}
+
+// -----
+
+func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
+ return %inserted_slice : tensor<8x1x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
+// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
+
+// -----
+
+func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
+ return %inserted_slice : tensor<1x4x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
+// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
+
+// -----
+
+func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
+ return %inserted_slice : tensor<1x1x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
+
+// -----
+
+func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
+ return %inserted_slice : tensor<1x4x4xf32>
+}
+// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
+
+// -----
+
+func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
+ return %inserted_slice : tensor<2x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>
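
The dim-matching loop in `InsertSliceOfExtractSliceFolder::matchAndRewrite` can be exercised outside MLIR. The following is a minimal standalone sketch, not the code from this PR: it swaps `llvm::SmallBitVector` for `std::vector<bool>`, and the function name `expandedDimsMatchDroppedDims` is hypothetical. It assumes the single-use and `isCastLikeInsertSliceOp` checks have already passed.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the pattern's matching loop (not an MLIR API).
// extractDropped[i] is true when extract_slice drops source dim i.
// insertExpanded[j] is true when insert_slice adds destination dim j as a
// unit dim.
bool expandedDimsMatchDroppedDims(const std::vector<bool> &extractDropped,
                                  const std::vector<bool> &insertExpanded) {
  // The insert_slice must not expand to a higher rank than the extract_slice
  // source.
  if (extractDropped.size() < insertExpanded.size())
    return false;

  std::size_t insertPos = 0;
  for (std::size_t extractPos = 0; extractPos < extractDropped.size();
       ++extractPos) {
    // Every insert_slice dim has been matched.
    if (insertPos == insertExpanded.size())
      break;

    bool isDropped = extractDropped[extractPos];
    bool isExpanded = insertExpanded[insertPos];
    if (isDropped == isExpanded) {
      // Both sides drop the dim or both keep it: move to the next insert dim.
      ++insertPos;
    } else if (!isDropped && isExpanded) {
      // A kept dim cannot stand in for an expanded unit dim.
      return false;
    }
    // Otherwise the dim is dropped only by extract_slice: skip it and retry
    // the same insert_slice dim against the next extract_slice dim.
  }
  // Fail if some insert_slice dims were left unmatched.
  return insertPos == insertExpanded.size();
}
```

For the first test above, the extract_slice drops dims 0 and 2 of `tensor<?x8x2x8xf32>` and the insert_slice expands dim 1 of `tensor<8x1x8xf32>`, so the masks `{1, 0, 1, 0}` and `{0, 1, 0}` match. In `@no_fold_more_unit_dims_insert_slice_of_extract_slice` the masks `{1, 0, 0}` and `{1, 1, 0, 0}` are rejected by the size check.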
✅ With the latest revision this PR passed the Python code formatter.
✅ With the latest revision this PR passed the C/C++ code formatter.
I think this pattern would fit best next to DropRedundantInsertSliceRankExpansion (maybe rename populateDropRedundantInsertSliceRankExpansionPatterns).
Done.
if (droppedDims.test(resultDim)) {
  // InsertSlice may expand unit dimensions that result from inserting a
  // size-1 slice into a non-size-1 result dimension.
  if (resultType.getDimSize(resultDim) != 1)
Why is this needed? If a dimension is "dropped", it must be a unit dimension.
This checks whether the dropped unit dim is inserted into a non-unit dim. For example:
tensor.insert_slice %x into %y[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
In this case, the first dim of the source 1x8x8xf32 is seen as a dropped unit dim, but it is inserted into 2x8x8xf32. I think in this case we don't consider it a cast-like op, as the source (1)x8x8xf32 != the destination 2x8x8xf32.
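
The check discussed in this thread can be illustrated with a simplified model. This is only a sketch under stated assumptions: it handles static shapes only, whereas the real `isCastLikeInsertSliceOp` compares sizes through `ValueBoundsConstraintSet::areEqual`, and the name `isCastLikeStaticShapes` and its parameters are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical static-shape model of the cast-like check (not an MLIR API).
// droppedDims[i] is true when destination dim i has no counterpart in the
// source, i.e. it is a unit dim added by the insert_slice.
bool isCastLikeStaticShapes(const std::vector<std::int64_t> &sourceShape,
                            const std::vector<std::int64_t> &destShape,
                            const std::vector<bool> &droppedDims) {
  if (droppedDims.size() != destShape.size())
    return false;

  std::size_t srcDim = 0;
  for (std::size_t resultDim = 0; resultDim < destShape.size(); ++resultDim) {
    if (droppedDims[resultDim]) {
      // The check added in this PR: a size-1 slice inserted into a larger
      // destination dim only covers part of it, so it is not cast-like.
      if (destShape[resultDim] != 1)
        return false;
      continue;
    }
    // Remaining dims must have the same size in source and destination.
    if (srcDim >= sourceShape.size() ||
        sourceShape[srcDim] != destShape[resultDim])
      return false;
    ++srcDim;
  }
  // Extra guard in this sketch only: all source dims must be consumed.
  return srcDim == sourceShape.size();
}
```

With the example from the reply, inserting `tensor<8x8xf32>` into `tensor<2x8x8xf32>` with dim 0 dropped is rejected, while inserting into `tensor<1x8x8xf32>` is accepted.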
LGTM, thanks
Fold the tensor.insert_slice of tensor.extract_slice into tensor.extract_slice when the insert_slice simply expands some unit dims dropped by the extract_slice.