Skip to content

[MLIR] Add pattern to bubble up tensor.extract_slice #126898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

// Transform dialect op exposing the tensor.extract_slice bubble-up patterns
// (populated via tensor::populateBubbleUpExtractSliceOpPatterns) to the
// apply_patterns transform machinery.
def ApplyBubbleUpExtractSlicePatternsOp : Op<Transform_Dialect,
    "apply_patterns.tensor.bubble_up_extract_slice",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let description = [{
    Indicates that producers of tensor.extract_slice should swap and operate on
    the result of the slice.
  }];

  let assemblyFormat = "attr-dict";
}

def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
"apply_patterns.tensor.rewrite_as_constant",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
void populateMergeConsecutiveInsertExtractSlicePatterns(
RewritePatternSet &patterns);

/// Appends patterns that are used to bubble up tensor.extract_slice op above
/// its producer. When used as cleanup patterns of tile and fuse, enables fusing
/// the producer with the consumer even if the producer does not implement the
/// tiling interface.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
/// rank expansions.
void populateDropRedundantInsertSliceRankExpansionPatterns(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
RewritePatternSet patterns(context);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}

Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
}

void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
}

void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
Expand Down
217 changes: 217 additions & 0 deletions mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;
Expand Down Expand Up @@ -210,6 +214,214 @@ struct BubbleUpExpandThroughParallelCollapse
}
};

/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. A slice is defined as
/// fully contiguous within a reassociation group if after flattening the
/// reassociation group to a single 1D range, then the slice taken out of the
/// group could be defined as a single contiguous subrange within that range.
///
/// Rank reducing slices are not supported.
///
/// Example:
/// The transformation is possible because each reassociation group has a
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
/// ```
/// BEFORE:
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %in ...
/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note - this pattern could be extended to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
struct BubbleUpExpandShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();

    // Match only slices that are provably contiguous within every
    // reassociation group of the producing expand_shape (see helper below).
    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
                                                 rewriter)
            .failed())
      return failure();

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
    // referring to the state before applying the pattern are named with the
    // prefix "expanded", and ones referring to the state after applying the
    // pattern are named with the prefix "collapsed".
    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> expandedShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // Helper variables and function for accumulating the size values.
    Location loc = expandShapeOp->getLoc();
    // Only two dims are needed: the multiply map below references d0 and d1.
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    // Folded multiplication of two index-typed values/attributes.
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank of
    // ReassociationIndices.size(). In the loop a single offset, size, and
    // stride value is computed per reassociation group.
    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
        collapsedStrides;
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      // collapsedSize will hold the size of the single dim that represents the
      // reassociation group in the non expanded tensor.
      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
      // The reassocGroupSizes and reassocGroupOffsets are used to create an
      // affine.linearize_index op to linearize the single offset value required
      // for this reassociation group.
      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;

      for (long expandedDim : indices) {
        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
        // from the expanded state, but the collapsed size requires calculation
        // as it did not previously exist.
        reassocGroupSizes.push_back(expandedShape[expandedDim]);
        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
      }

      SmallVector<Value> offsetVals =
          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
      // Linearize the per-dim offsets of the group into the single offset of
      // the collapsed dim; disjoint=true is valid because each offset is
      // within the bounds of its corresponding expanded dim.
      OpFoldResult collapsedOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
                                                      reassocGroupSizes,
                                                      /*disjoint=*/true)
              .getResult();
      collapsedOffsets.push_back(collapsedOffset);
      collapsedSizes.push_back(collapsedSize);

      // Only unit stride is supported.
      collapsedStrides.push_back(rewriter.getIndexAttr(1));
    }

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
        collapsedStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), expandedSizes);
    return success();
  }

  // Helper function to check if all the required conditions for the
  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
  // met.
  LogicalResult
  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
                                           tensor::ExpandShapeOp expandShapeOp,
                                           PatternRewriter &rewriter) const {

    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // True iff the slice starts at 0 on this dim and covers the full dim
    // extent (equality proven via the value-bounds analysis). A plain lambda
    // avoids the type-erasure overhead of std::function.
    auto isZeroOffsetAndFullSize = [](OpFoldResult offset,
                                      OpFoldResult sliceSize,
                                      OpFoldResult size) {
      if (!isConstantIntValue(offset, 0))
        return false;
      FailureOr<bool> maybeEqual =
          ValueBoundsConstraintSet::areEqual(sliceSize, size);
      return llvm::succeeded(maybeEqual) && maybeEqual.value();
    };

    // Check that the slice is contiguous within each reassociation group.
    // The slice is contiguous only if after the first dimension where a non
    // unit slice is taken, the slice size on all subsequent dimensions of the
    // group is equal to the entire size of the dimension.
    // Examples of contiguous slices:
    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
    // Examples of non contiguous slices:
    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Find the first expanded dim after the first dim with non-unit extracted
      // size.
      for (; i < e; ++i) {
        if (!isConstantIntValue(sizes[indices[i]], 1)) {
          // +1 to skip the first non-unit size dim.
          i++;
          break;
        }
      }

      // Verify that all subsequent dimensions extract the full size of the
      // source tensor.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "Not a contiguous slice of the expanded tensor.");
        }
      }
    }

    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
Expand All @@ -227,3 +439,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
RewritePatternSet &patterns) {
patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  // Register the expand_shape/extract_slice bubble-up pattern.
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(context);
}
Loading