Skip to content

Commit f566b07

Browse files
author
Jerry Wu
authored
[MLIR] Add pattern to fold insert_slice of extract_slice (#86328)
Fold the `tensor.insert_slice` of `tensor.extract_slice` into `tensor.extract_slice` when the `insert_slice` simply expands some unit dims dropped by the `extract_slice`.
1 parent 94b5c11 commit f566b07

File tree

3 files changed

+168
-6
lines changed

3 files changed

+168
-6
lines changed

mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
7878
}
7979
};
8080

81-
/// Drop redundant rank expansion. I.e., rank expansions that are directly
82-
/// followed by rank reductions. E.g.:
81+
/// Drop redundant rank expansion of insert_slice that are directly followed
82+
/// by extract_slice. E.g.:
8383
/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
8484
/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
8585
/// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
86-
struct DropRedundantInsertSliceRankExpansion
86+
struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
8787
: public OpRewritePattern<ExtractSliceOp> {
8888
using OpRewritePattern::OpRewritePattern;
8989

@@ -134,6 +134,97 @@ struct DropRedundantInsertSliceRankExpansion
134134
return success();
135135
}
136136
};
137+
138+
/// Drop redundant rank expansion of insert_slice that directly follows
139+
/// extract_slice.
140+
///
141+
/// This can be done when the insert_slice op purely expands ranks (adds unit
142+
/// dims) and the extract_slice drops corresponding unit dims. For example:
143+
///
144+
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
145+
/// : tensor<2x8xf32> to tensor<8xf32>
146+
/// %inserted_slice = tensor.insert_slice %extracted_slice
147+
/// into %dest[0, 0] [1, 8] [1, 1]
148+
/// : tensor<8xf32> into tensor<1x8xf32>
149+
///
150+
/// can be folded into:
151+
///
152+
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
153+
/// : tensor<2x8xf32> to tensor<1x8xf32>
154+
struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
155+
: public OpRewritePattern<tensor::InsertSliceOp> {
156+
using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
157+
158+
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
159+
PatternRewriter &rewriter) const {
160+
auto extractSliceOp =
161+
insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162+
if (!extractSliceOp) {
163+
return rewriter.notifyMatchFailure(insertSliceOp,
164+
"source is not extract_slice");
165+
}
166+
167+
// Can't fold if the extract_slice op has other users.
168+
if (!extractSliceOp->hasOneUse()) {
169+
return rewriter.notifyMatchFailure(insertSliceOp,
170+
"source has multi-uses");
171+
}
172+
173+
// Check if the insert_slice op purely expands ranks (add unit dims).
174+
if (!isCastLikeInsertSliceOp(insertSliceOp)) {
175+
return rewriter.notifyMatchFailure(insertSliceOp,
176+
"insert_slice is not cast-like");
177+
}
178+
179+
llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180+
llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
181+
// Can't fold if the insert_slice op expands to more dims.
182+
if (extractDroppedDims.size() < insertDroppedDims.size()) {
183+
return rewriter.notifyMatchFailure(insertSliceOp,
184+
"insert_slice expands more dims");
185+
}
186+
187+
// Try to match the extract dropped dims to the insert dropped dims. This is
188+
// done by scanning the dims of extract_slice and finding the left-most one that can
189+
// match the dim of insert_slice. If a match is found, advance the dim of
190+
// insert_slice to match the next one.
191+
unsigned insertDimPos = 0;
192+
for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
193+
++extractDimPos) {
194+
// Matched all dims.
195+
if (insertDimPos == insertDroppedDims.size())
196+
break;
197+
198+
bool isExtractDropped = extractDroppedDims[extractDimPos];
199+
bool isInsertDropped = insertDroppedDims[insertDimPos];
200+
// Match if both sides drop/keep the dim. Advance and match the next dim
201+
// of insert_slice.
202+
if (isExtractDropped == isInsertDropped) {
203+
insertDimPos += 1;
204+
} else if (!isExtractDropped && isInsertDropped) {
205+
// Not enough extract dropped dims to match the insert dropped dims.
206+
return rewriter.notifyMatchFailure(insertSliceOp,
207+
"insert_slice drops more unit dims");
208+
}
209+
// If the dim is dropped by extract_slice and not by insert_slice, look at
210+
// the next dim of extract_slice to see if it can match the current dim of
211+
// insert_slice.
212+
}
213+
// Can't match some insert dims.
214+
if (insertDimPos != insertDroppedDims.size()) {
215+
return rewriter.notifyMatchFailure(insertSliceOp,
216+
"insert_slice has unmatched dims");
217+
}
218+
219+
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
220+
insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221+
extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222+
extractSliceOp.getMixedStrides());
223+
rewriter.eraseOp(extractSliceOp);
224+
225+
return success();
226+
}
227+
};
137228
} // namespace
138229

139230
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
@@ -146,5 +237,7 @@ void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
146237

147238
void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
148239
RewritePatternSet &patterns) {
149-
patterns.add<DropRedundantInsertSliceRankExpansion>(patterns.getContext());
240+
patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241+
DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
242+
patterns.getContext());
150243
}

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,15 @@ mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
142142
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
143143
llvm::SmallBitVector droppedDims = op.getDroppedDims();
144144
int64_t srcDim = 0;
145+
RankedTensorType resultType = op.getDestType();
145146
// Source dims and destination dims (apart from dropped dims) must have the
146147
// same size.
147-
for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
148-
++resultDim) {
148+
for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
149149
if (droppedDims.test(resultDim)) {
150+
// InsertSlice may expand unit dimensions that result from inserting a
151+
// size-1 slice into a non-size-1 result dimension.
152+
if (resultType.getDimSize(resultDim) != 1)
153+
return false;
150154
continue;
151155
}
152156
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(

mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,68 @@ func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1
99
%extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
1010
return %extracted_slice : tensor<123x456xf32>
1111
}
12+
13+
// -----
14+
15+
func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
16+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
17+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
18+
return %inserted_slice : tensor<8x1x8xf32>
19+
}
20+
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
21+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
22+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
23+
// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
24+
// CHECK: return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
25+
26+
// -----
27+
28+
func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
29+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
30+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
31+
return %inserted_slice : tensor<1x4x8xf32>
32+
}
33+
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
34+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
35+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
36+
// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
37+
// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
38+
39+
// -----
40+
41+
func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
42+
%extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
43+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
44+
return %inserted_slice : tensor<1x1x8x8xf32>
45+
}
46+
// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
47+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x8xf32>
48+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
49+
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
50+
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
51+
52+
// -----
53+
54+
func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
55+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
56+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
57+
return %inserted_slice : tensor<1x4x4xf32>
58+
}
59+
// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
60+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
61+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
62+
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
63+
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
64+
65+
// -----
66+
67+
func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
68+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
69+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
70+
return %inserted_slice : tensor<2x8x8xf32>
71+
}
72+
// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
73+
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
74+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
75+
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
76+
// CHECK: return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>

0 commit comments

Comments
 (0)