-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape #131982
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
72fbf71
c0291d0
72b0be3
1aaf3c9
5845db6
3c69390
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
#include "mlir/Dialect/Tensor/Transforms/Transforms.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Interfaces/ValueBoundsOpInterface.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/LogicalResult.h" | ||
|
||
|
@@ -428,6 +429,253 @@ struct BubbleUpExpandShapeThroughExtractSlice | |
} | ||
}; | ||
|
||
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to | ||
/// `tensor.collapse_shape(tensor.extract_slice)`. | ||
/// | ||
/// For this transformation to be possible - after bubbling up, the extraction | ||
/// of the contiguous slice must be representable as a single slice obtained via | ||
/// tensor.extract_slice within each reassociation group of the src. | ||
/// | ||
/// In case the size and offset extracted are static then this is possible if | ||
/// the following conditions are met within each reassociation group: | ||
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the | ||
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the | ||
/// shape of a desired slice. A slice of shape S can be extracted as a | ||
/// contiguous span of elements if and only if there exists an index k in {0, 1, | ||
/// ..., n} such that: | ||
/// S_i = 1 for all i < k (that is, all leading dimensions are singleton), | ||
/// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly | ||
/// one dimension), | ||
/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved | ||
/// in full). | ||
/// In other words, the slice shape S must be of the form: | ||
/// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] | ||
/// | ||
/// In case the size and/or offset extracted are dynamic then this is possible | ||
/// only if there is single dimension in the reassociation group that has a size | ||
/// not equal to 1. | ||
/// In other words, the tensor shape must be of the form: | ||
/// [ 1, 1, ..., 1, A, 1, ...,1 ] | ||
/// Note - it might be possible to enable this pattern for more cases when the | ||
/// size/offset are dynamic via performing an analysis of the possible values | ||
/// that could be given to the size/offset. | ||
/// | ||
/// Example: | ||
/// The transformation is possible because each reassociation group can be | ||
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?], | ||
/// [20->10]). | ||
/// ``` | ||
/// BEFORE: | ||
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ... | ||
/// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32> | ||
/// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1] | ||
/// tensor<128x7x20xf32> to tensor<32x?x10xf32> | ||
/// | ||
/// AFTER: | ||
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10] | ||
// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32> | ||
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ... | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whew! This is actually correct. Took me a while to work out that it, but this works because for the reassociation the other dimensions are all 1. |
||
/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32> | ||
/// ``` | ||
/// | ||
/// Negative example: | ||
/// The transformation is not possible because we cannot use a single slice to | ||
/// represent the reassociation group [2x3x10->???]. If we would want the | ||
/// collapse to be after the extraction, we would need to extract multiple | ||
/// slices and concat them together. | ||
/// ``` | ||
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into | ||
/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] : | ||
/// tensor<60xf32> to tensor<15xf32> | ||
/// ``` | ||
/// If we would want the collapse to be after the extraction, a possible | ||
/// alternate transformation could be to extract multiple slices and concat them | ||
/// together: | ||
/// ``` | ||
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] : | ||
/// tensor<2x3x10xf32> to tensor <1x1x10xf32> | ||
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] : | ||
/// tensor<2x3x10xf32> to tensor <1x1x5xf32> | ||
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} : | ||
/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32> | ||
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32> | ||
/// to tensor<15xf32> | ||
/// ``` | ||
/// But this is not the intended purpose of the transformation. | ||
struct BubbleUpCollapseShapeThroughExtractSlice | ||
: public OpRewritePattern<tensor::ExtractSliceOp> { | ||
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, | ||
PatternRewriter &rewriter) const override { | ||
auto collapseShapeOp = | ||
sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); | ||
if (!collapseShapeOp) | ||
return rewriter.notifyMatchFailure( | ||
sliceOp, | ||
"tensor.extract_slice source not produced by tensor.collapse_shape"); | ||
|
||
if (!sliceOp.hasUnitStride()) { | ||
return rewriter.notifyMatchFailure( | ||
sliceOp, "unsupported: non-unit stride. Only contiguous slices can " | ||
"be supported in this transformation."); | ||
} | ||
|
||
// The tensor.extract_slice before applying the pattern works on the result | ||
// of the tensor.collapse_shape, so variables (i.e. inputs for | ||
// ExtractSliceOp) referring to the state before applying the pattern are | ||
// named with the prefix "collapsed", and ones referring to the state after | ||
// applying the pattern are named with the prefix "expanded". | ||
SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); | ||
SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); | ||
|
||
if (static_cast<size_t>(sliceOp.getResultType().getRank()) != | ||
collapsedSizes.size()) | ||
return rewriter.notifyMatchFailure(sliceOp, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Please add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added in all the relevant places in the pattern |
||
"unimplemented: rank reducing slice"); | ||
|
||
ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape(); | ||
SmallVector<ReassociationIndices, 4> reassociationIndices = | ||
collapseShapeOp.getReassociationIndices(); | ||
|
||
// Compute new offsets, sizes, and strides for tensor.extract_slice. | ||
// The new tensor.extract_slice will work on a tensor that has has a rank | ||
// equal to the rank of the src of the collapse_shape. In each iteration of | ||
// the loop, the offsets and sizes will be computed per reassociation group. | ||
SmallVector<OpFoldResult> expandedOffsets, expandedSizes; | ||
SmallVector<OpFoldResult> expandedStrides(srcShape.size(), | ||
rewriter.getIndexAttr(1)); | ||
|
||
for (auto [groupIdx, reassocIndices] : | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: You could do
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! updated |
||
enumerate(collapseShapeOp.getReassociationIndices())) { | ||
OpFoldResult collapsedSize = collapsedSizes[groupIdx]; | ||
OpFoldResult collapsedOffset = collapsedOffsets[groupIdx]; | ||
// CASE #1 - size and/or offset are dynamic. | ||
// In this case, the slice can be represented as a contiguous slice only | ||
// if there is a single dimension in the reassociation group that has a | ||
// size not equal to 1. | ||
if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { | ||
int nonUnitSizeCount = 0; | ||
for (int64_t expandedShapeIdx : reassocIndices) { | ||
if (srcShape[expandedShapeIdx] != 1) { | ||
nonUnitSizeCount++; | ||
expandedSizes.push_back(collapsedSize); | ||
expandedOffsets.push_back(collapsedOffset); | ||
continue; | ||
} | ||
|
||
expandedSizes.push_back(rewriter.getIndexAttr(1)); | ||
expandedOffsets.push_back(rewriter.getIndexAttr(0)); | ||
} | ||
|
||
if (nonUnitSizeCount != 1) { | ||
return rewriter.notifyMatchFailure( | ||
sliceOp, | ||
"unsupported: slice cannot be verified to be contiguous"); | ||
} | ||
continue; | ||
} | ||
|
||
// CASE #2 = size and offset are static. | ||
// Verify that the slice can be represented as a contiguous slice of the | ||
// src of the collapse_shape. | ||
// Checking this is done on order of most internal dimensions first, | ||
// so traversal is done in reverse order of the reassociation group. | ||
// If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, | ||
// ...,An] then we first find the size and offset for n...k+1 then for k | ||
// and then for k-1...0. | ||
|
||
// currentCollapsedsize and currentCollapsedOffset are initialized with | ||
// the original collapsed size and offset and divided by the expanded | ||
// shape size in each dimension as we go along the reassociation group. | ||
// In essence we are spreading the original collapsed size and offset over | ||
// the various expanded slice dimensions. | ||
// The variables are used both to check the validity of the slice and to | ||
// compute the expanded sizes and offsets. | ||
int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); | ||
int64_t currentCollapsedOffset = | ||
getConstantIntValue(collapsedOffset).value(); | ||
|
||
SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; | ||
|
||
ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), | ||
reassocIndices.rend()); | ||
int64_t idx = 0; | ||
int64_t reassocGroupSize = reassocIndices.size(); | ||
|
||
// First handle the trailing dimensions where the slice size should be | ||
// equal to the tensor shape and the offset should be 0 (n...k+1). | ||
for (; idx < reassocGroupSize; ++idx) { | ||
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; | ||
|
||
if (currentCollapsedsize < expandedShapeSize) | ||
break; | ||
|
||
// We need to make sure that the slice size can be set to the shape size | ||
// and the offset to 0. | ||
if ((currentCollapsedsize % expandedShapeSize) != 0 || | ||
(currentCollapsedOffset % expandedShapeSize) != 0) | ||
return rewriter.notifyMatchFailure( | ||
sliceOp, "unsupported: cannot be extracted as a contiguous slice " | ||
"of the src of the collapse_shape"); | ||
|
||
groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); | ||
groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); | ||
|
||
currentCollapsedsize /= expandedShapeSize; | ||
currentCollapsedOffset /= expandedShapeSize; | ||
} | ||
|
||
// Now handle the first dim where slicing occurs on (k). | ||
if (idx < reassocGroupSize) { | ||
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; | ||
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; | ||
// We need to make sure that the slice size in this dim + offset will | ||
// not exceed the shape size. | ||
if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) | ||
return rewriter.notifyMatchFailure( | ||
sliceOp, "unsupported: slice cannot be extracted as a contiguous " | ||
"slice of the src of the collapse_shape"); | ||
|
||
groupExpandedSizes.push_back( | ||
rewriter.getIndexAttr(currentCollapsedsize)); | ||
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); | ||
|
||
currentCollapsedOffset /= expandedShapeSize; | ||
} | ||
|
||
// Now handle the leading dimensions where the slice size is equal to 1 | ||
// (k-1...0). | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// The size for these dimensions must be 1 because of how we constructed | ||
// the slice size of the expanded shape. We spread the original collapsed | ||
// size over the expanded shape sizes until we reached dimension k where | ||
// the remaining size was smaller than the expanded shape size, and spread | ||
// the remaining size on it. So, now we are left with only 1s. | ||
for (idx++; idx < reassocGroupSize; ++idx) { | ||
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; | ||
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; | ||
groupExpandedSizes.push_back(rewriter.getIndexAttr(1)); | ||
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); | ||
currentCollapsedOffset /= expandedShapeSize; | ||
} | ||
|
||
expandedSizes.append(groupExpandedSizes.rbegin(), | ||
groupExpandedSizes.rend()); | ||
expandedOffsets.append(groupExpandedOffsets.rbegin(), | ||
groupExpandedOffsets.rend()); | ||
} | ||
|
||
Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>( | ||
collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets, | ||
expandedSizes, expandedStrides); | ||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( | ||
sliceOp, sliceOp.getResultType(), newSliceOp, | ||
collapseShapeOp.getReassociationIndices()); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::tensor::populateReassociativeReshapeFoldingPatterns( | ||
|
@@ -448,5 +696,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns( | |
|
||
void mlir::tensor::populateBubbleUpExtractSliceOpPatterns( | ||
RewritePatternSet &patterns) { | ||
patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext()); | ||
patterns.add<BubbleUpExpandShapeThroughExtractSlice, | ||
BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I struggle with this part a bit. Shouldn't there be only one reassociation group in which slicing happens? And, within that group, exactly one expanded dim should be sliced? If I am incorrect, is there an example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll start out with the second part of your question regarding each individual reassociation group.
We need to make sure that after we bubble up the
tensor.extract_slice
that the data we extract from the expanded tensor for this group is contiguous, since it was contiguous before the bubble up.Now lets look at the post bubble up sizes [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] and see why this makes sense.
If we start with the trailing dims, we want them all the be full, which means that the data is obviously contiguous. Then we have this one dim where we extract a slice that is smaller than the full size Sk, but things are still contiguous since we now have a contiguous slice of size Sk*(Prod(Ai) for i=k+1 to n). Note this this slice might not be with offset 0, but that is still fine since the original slice we extracted might not have been with offset 0.
Then we have all those leading dime of size 1. Since they are size 1 this means that we could say that we are technically slicing on this dim but in practice the size 1 means that we aren't breaking contiguity and just maybe changing the offset of the contiguous block that we extract.
to summarize - we could technically say that we are slicing on k different dims, but in practice the result is a single contiguous block of data as required.
Now regarding the first part of the question about different reassociation groups.
The data extracted from the collapsed tensor might have been sliced on multiple different collapsed dims which could make the entire original slice to not be contiguous, so it should be fine if slicing occurs on multiple different reassociation groups after the bubble up.
I'll give an example for this from the tests:
In the example there are 2 different reassociation groups, where a slice is extracted from each one of them, and the result of the bubble up is still legal and the result of the
tensor.collapse_shape
post bubble up represents the same data as the result of thetensor.extract_slice
pre bubble up.We can also directly calculate the indices within the src tensor of the elements from the result and see that they are equal in both cases: [160:169], [190:199], [220:229]...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for taking the time to write this very comprehensive explanation, this makes it super clear to me 🙏🏻 (and I was indeed incorrect)
I was just about to ask you to add this example as a test, but that's already done! :)