Skip to content

Commit dceb7b7

Browse files
author
Jerry Wu
committed
Add pattern to fold insert_slice of extract_slice
1 parent ca2607a commit dceb7b7

File tree

3 files changed

+108
-12
lines changed

3 files changed

+108
-12
lines changed

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class InsertSliceOfTransferWriteOpFolder final
6767
PatternRewriter &rewriter) const override;
6868
};
6969

70+
/// Merge insert_slice operation with extract_slice operation.
7071
class InsertSliceOfExtractSliceFolder final
7172
: public OpRewritePattern<tensor::InsertSliceOp> {
7273
public:
@@ -158,41 +159,69 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
158159
return success();
159160
}
160161

162+
/// Merge insert_slice operation with extract_slice operation.
163+
///
164+
/// This can be done when the insert_slice op purely expands ranks (adds unit
165+
/// dims) and the extrace_slice drops corresponding unit dims. For example:
166+
///
167+
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
168+
/// : tensor<2x8xf32> to tensor<8xf32>
169+
/// %inserted_slice = tensor.insert_slice %extracted_slice
170+
/// into %dest[0, 0] [1, 8] [1, 1]
171+
/// : tensor<8xf32> into tensor<1x8xf32>
172+
///
173+
/// can be folded into:
174+
///
175+
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
176+
/// : tensor<2x8xf32> to tensor<1x8xf32>
161177
LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
162178
tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
163179
auto extractSliceOp =
164180
insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
165181
if (!extractSliceOp)
166182
return failure();
167183

184+
// Can't fold if the extract_slice op has other users.
168185
if (!extractSliceOp->hasOneUse())
169186
return failure();
170187

188+
// Check if the insert_slice op purely expands ranks (add unit dims).
171189
if (!isCastLikeInsertSliceOp(insertSliceOp))
172190
return failure();
173191

174192
llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
175193
llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
194+
// Can't fold if the insert_slice op expands to more dims.
176195
if (extractDroppedDims.size() < insertExpandedDims.size())
177196
return failure();
178197

179-
int64_t insertPos = 0;
180-
for (int64_t extractPos = 0; extractPos < extractDroppedDims.size();
181-
++extractPos) {
182-
if (insertPos == insertExpandedDims.size())
198+
// Try to match the dropped unit dims to the expanded unit dims. This is done
199+
// by scanning the dims of extract_slice and find the left-most one can match
200+
// the dim of insert_slice. If a match is found, advance the dim of
201+
// insert_slice to match the next one.
202+
unsigned insertDimPos = 0;
203+
for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
204+
++extractDimPos) {
205+
// Matched all expanded dims.
206+
if (insertDimPos == insertExpandedDims.size())
183207
break;
184208

185-
bool isDropped = extractDroppedDims[extractPos];
186-
bool isExpanded = insertExpandedDims[insertPos];
209+
bool isDropped = extractDroppedDims[extractDimPos];
210+
bool isExpanded = insertExpandedDims[insertDimPos];
211+
// Match if both sides drop/keep the dim. Advance and match the next dim of
212+
// insert_slice.
187213
if (isDropped == isExpanded) {
188-
insertPos += 1;
189-
} else {
190-
if (!isDropped && isExpanded) {
191-
return failure();
192-
}
214+
insertDimPos += 1;
215+
} else if (!isDropped && isExpanded) {
216+
// Not enough dropped unit dims to match the expanded unit dims.
217+
return failure();
193218
}
219+
// If the dim is dropped by extract_slice and not by insert_slice, look the
220+
// next dim of extract_slice to see if it can match the current dim of
221+
// insert_slice.
194222
}
195-
if (insertPos != insertExpandedDims.size())
223+
// Can't match some expanded dims.
224+
if (insertDimPos != insertExpandedDims.size())
196225
return failure();
197226

198227
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
147147
// same size.
148148
for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
149149
if (droppedDims.test(resultDim)) {
150+
// InsertSlice may expand unit dimensions that result from inserting a
151+
// size-1 slice into a non-size-1 result dimension.
150152
if (resultType.getDimSize(resultDim) != 1)
151153
return false;
152154
continue;

mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,68 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
390390
}
391391
return %0: tensor<12x34xf32>
392392
}
393+
394+
// -----
395+
396+
func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
397+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
398+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
399+
return %inserted_slice : tensor<8x1x8xf32>
400+
}
401+
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
402+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
403+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
404+
// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
405+
// CHECK: return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
406+
407+
// -----
408+
409+
func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
410+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
411+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
412+
return %inserted_slice : tensor<1x4x8xf32>
413+
}
414+
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
415+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
416+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
417+
// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
418+
// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
419+
420+
// -----
421+
422+
func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
423+
%extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
424+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
425+
return %inserted_slice : tensor<1x1x8x8xf32>
426+
}
427+
// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
428+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x8xf32>
429+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
430+
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
431+
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
432+
433+
// -----
434+
435+
func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
436+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
437+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
438+
return %inserted_slice : tensor<1x4x4xf32>
439+
}
440+
// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
441+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
442+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
443+
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
444+
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
445+
446+
// -----
447+
448+
func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
449+
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
450+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
451+
return %inserted_slice : tensor<2x8x8xf32>
452+
}
453+
// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
454+
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
455+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
456+
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
457+
// CHECK: return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>

0 commit comments

Comments
 (0)