Skip to content

[MLIR] Add pattern to fold insert_slice of extract_slice #86328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
}
};

/// Drop redundant rank expansion. I.e., rank expansions that are directly
/// followed by rank reductions. E.g.:
/// Drop redundant rank expansions of insert_slice that are directly followed
/// by extract_slice. E.g.:
/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
/// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
struct DropRedundantInsertSliceRankExpansion
struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
: public OpRewritePattern<ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -134,6 +134,97 @@ struct DropRedundantInsertSliceRankExpansion
return success();
}
};

/// Drop redundant rank expansion of insert_slice that directly follows
/// extract_slice.
///
/// This can be done when the insert_slice op purely expands ranks (adds unit
/// dims) and the extract_slice drops corresponding unit dims. For example:
///
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
///     : tensor<2x8xf32> to tensor<8xf32>
/// %inserted_slice = tensor.insert_slice %extracted_slice
///     into %dest[0, 0] [1, 8] [1, 1]
///     : tensor<8xf32> into tensor<1x8xf32>
///
/// can be folded into:
///
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
///     : tensor<2x8xf32> to tensor<1x8xf32>
struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractSliceOp) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "source is not extract_slice");
    }

    // Can't fold if the extract_slice op has other users.
    if (!extractSliceOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "source has multi-uses");
    }

    // Check if the insert_slice op purely expands ranks (add unit dims).
    if (!isCastLikeInsertSliceOp(insertSliceOp)) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "insert_slice is not cast-like");
    }

    llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
    llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
    // Can't fold if the insert_slice op expands to more dims.
    if (extractDroppedDims.size() < insertDroppedDims.size()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "insert_slice expands more dims");
    }

    // Try to match the extract dropped dims to the insert dropped dims. This is
    // done by scanning the dims of extract_slice and find the left-most one can
    // match the dim of insert_slice. If a match is found, advance the dim of
    // insert_slice to match the next one.
    unsigned insertDimPos = 0;
    for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
         ++extractDimPos) {
      // Matched all dims.
      if (insertDimPos == insertDroppedDims.size())
        break;

      bool isExtractDropped = extractDroppedDims[extractDimPos];
      bool isInsertDropped = insertDroppedDims[insertDimPos];
      // Match if both sides drop/keep the dim. Advance and match the next dim
      // of insert_slice.
      if (isExtractDropped == isInsertDropped) {
        insertDimPos += 1;
      } else if (!isExtractDropped && isInsertDropped) {
        // Not enough extract dropped dims to match the insert dropped dims.
        return rewriter.notifyMatchFailure(insertSliceOp,
                                           "insert_slice drops more unit dims");
      }
      // If the dim is dropped by extract_slice and not by insert_slice, look
      // the next dim of extract_slice to see if it can match the current dim of
      // insert_slice.
    }
    // Can't match some insert dims.
    if (insertDimPos != insertDroppedDims.size()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "insert_slice has unmatched dims");
    }

    // Fold the pair into a single extract_slice that produces the
    // rank-expanded result type directly; the original extract_slice is now
    // dead (its sole user was insertSliceOp) and is erased explicitly.
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
        extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
        extractSliceOp.getMixedStrides());
    rewriter.eraseOp(extractSliceOp);

    return success();
  }
};
} // namespace

void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
Expand All @@ -146,5 +237,7 @@ void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(

void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
    RewritePatternSet &patterns) {
  // Register both folding directions: extract_slice-of-insert_slice and
  // insert_slice-of-extract_slice rank-expansion removal.
  patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
               DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
      patterns.getContext());
}
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Tensor/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,15 @@ mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
RankedTensorType resultType = op.getDestType();
// Source dims and destination dims (apart from dropped dims) must have the
// same size.
for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
++resultDim) {
for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
if (droppedDims.test(resultDim)) {
// InsertSlice may expand unit dimensions that result from inserting a
// size-1 slice into a non-size-1 result dimension.
if (resultType.getDimSize(resultDim) != 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? If a dimension is "dropped", it must be a unit dimension.

Copy link
Member Author

@pzread pzread Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to check if the dropped unit dim is inserted into non-unit dim. For example:

tensor.insert_slice %x into %y[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>

in this case, the first dim of the source 1x8x8xf32 is seen as a dropped unit dim, but it is inserted into 2x8x8xf32. I think in this case we don't consider it as a cast-like op as the source (1)x8x8xf32 != the destination 2x8x8xf32

return false;
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,68 @@ func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1
%extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
return %extracted_slice : tensor<123x456xf32>
}

// -----

// The insert_slice only re-adds a unit dim (middle dim of 8x1x8) after the
// extract_slice dropped unit dims, so the pair folds into one extract_slice
// that produces the rank-expanded type 8x1x8 directly.
func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
  return %inserted_slice : tensor<8x1x8xf32>
}
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>

// -----

// Strides on the extract_slice are preserved by the fold; only the
// insert_slice must be cast-like (stride-1, pure unit-dim expansion), so the
// strided extract folds to a single extract with result type 1x4x8.
func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
  return %inserted_slice : tensor<1x4x8xf32>
}
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>

// -----

// Negative test: insert_slice expands two unit dims while extract_slice only
// dropped one, so the dropped dims cannot be matched and no fold happens.
func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
  return %inserted_slice : tensor<1x1x8x8xf32>
}
// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>

// -----

// Negative test: the insert_slice uses non-unit strides [1, 2, 2], so it is
// not a pure rank expansion (not cast-like) and the pair is not folded.
func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
  return %inserted_slice : tensor<1x4x4xf32>
}
// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>

// -----

// Negative test: the insert_slice writes a size-1 slice into a size-2
// destination dim (2x8x8), so the result is not equal to the source and the
// op is not cast-like; no fold happens.
func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
  return %inserted_slice : tensor<2x8x8xf32>
}
// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>