Skip to content

Commit b8903a8

Browse files
author
Ofri Frishman
committed
[MLIR] Add pattern to bubble up tensor.extract_slice
Add a pattern that bubbles up tensor.extract_slice through tensor.expand_shape, and add a transform op to the tensor dialect to use this pattern directly. When added as a cleanup pattern of the tile and fuse utility, this pattern enables tiling and fusing op chains which contain tensor.expand_shape. Without this pattern that would not be possible, as tensor.expand_shape does not implement the tiling interface. In addition, register this pattern as a cleanup pattern for transform.structured.fuse. The pattern was first implemented in the IREE project by Quinn Dawkins and is being upstreamed.
1 parent 73413bd commit b8903a8

File tree

7 files changed

+526
-0
lines changed

7 files changed

+526
-0
lines changed

mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,17 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
111111
let assemblyFormat = "attr-dict";
112112
}
113113

114+
def ApplyBubbleUpExtractSlicePatternsOp : Op<Transform_Dialect,
115+
"apply_patterns.tensor.bubble_up_extract_slice",
116+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
117+
let description = [{
118+
Indicates that producers of tensor.extract_slice should swap and operate on
119+
the result of the slice.
120+
}];
121+
122+
let assemblyFormat = "attr-dict";
123+
}
124+
114125
def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
115126
"apply_patterns.tensor.rewrite_as_constant",
116127
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
5858
void populateMergeConsecutiveInsertExtractSlicePatterns(
5959
RewritePatternSet &patterns);
6060

61+
/// Appends patterns that are used to bubble up tensor.extract_slice op above
62+
/// its producer. When used as cleanup patterns of tile and fuse, enables fusing
63+
/// the producer with the consumer even if the producer does not implement the
64+
/// tiling interface.
65+
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
66+
6167
/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
6268
/// rank expansions.
6369
void populateDropRedundantInsertSliceRankExpansionPatterns(

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
582582
RewritePatternSet patterns(context);
583583
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
584584
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
585+
tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
585586
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
586587
}
587588

mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
125125
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
126126
}
127127

128+
void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
129+
RewritePatternSet &patterns) {
130+
tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
131+
}
132+
128133
void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129134
RewritePatternSet &patterns) {
130135
ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Arith/Utils/Utils.h"
911
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1012
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1113
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1215
#include "llvm/Support/Debug.h"
16+
#include "llvm/Support/LogicalResult.h"
1317

1418
using namespace mlir;
1519
using namespace mlir::tensor;
@@ -210,6 +214,217 @@ struct BubbleUpExpandThroughParallelCollapse
210214
}
211215
};
212216

217+
/// Converts `tensor.extract_slice(tensor.expand_shape)` to
218+
/// `tensor.expand_shape(tensor.extract_slice)`.
219+
/// For this transformation to be possible, the slice must be fully contiguous
220+
/// within each reassociation group of the expand_shape.
221+
/// A slice is defined as fully contiguous within a reassociation group if after
222+
/// flattening the reassociation group to a single 1D range, then the slice
223+
/// taken out of the group could be defined as a single contiguous subrange
224+
/// within that range.
225+
/// If the transformation is not possible, or if the slice is rank reducing, the
226+
/// function returns failure.
227+
///
228+
/// Example:
229+
/// ```
230+
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
231+
/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
232+
/// %slice = tensor.extract_slice %reshape ...
233+
/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
234+
///
235+
/// // The transformation is possible because each reassociation group has a
236+
/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
237+
/// // After the transformation:
238+
///
239+
/// %slice = tensor.extract_slice %in ...
240+
/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
241+
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
242+
/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
243+
/// ```
244+
///
245+
/// Note - this pattern could be reworked to be a swap pattern between
246+
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
247+
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
248+
struct BubbleUpExpandShapeThroughExtractSlice
249+
: public OpRewritePattern<tensor::ExtractSliceOp> {
250+
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
251+
252+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
253+
PatternRewriter &rewriter) const override {
254+
auto expandShapeOp =
255+
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
256+
257+
if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
258+
rewriter)
259+
.failed())
260+
return failure();
261+
262+
// The tensor.extract_slice before applying the pattern works on the result
263+
// of the tensor.expand_shape, so variables referring to the state before
264+
// applying the pattern are named with the prefix "expanded", and ones
265+
// referring to the state after applying the pattern are named with the
266+
// prefix "collapsed".
267+
SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
268+
SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
269+
SmallVector<OpFoldResult> expandedShape =
270+
getMixedValues(expandShapeOp.getStaticOutputShape(),
271+
expandShapeOp.getOutputShape(), rewriter);
272+
273+
// Helper variables and function for accumulating the size values.
274+
Location loc = expandShapeOp->getLoc();
275+
AffineExpr d0, d1, d2;
276+
bindDims(rewriter.getContext(), d0, d1, d2);
277+
// Multiply two integers.
278+
auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
279+
auto mulMap = AffineMap::get(2, 0, {d0 * d1});
280+
return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
281+
{v1, v2});
282+
};
283+
284+
// Compute new offsets, sizes, and strides for tensor.extract_slice.
285+
// The new tensor.extract_slice will work on a tensor that has a rank of
286+
// ReassociationIndices.size(). In the loop a single offset, size, and
287+
// stride value is computed per reassociation group.
288+
SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
289+
collapsedStrides;
290+
for (const ReassociationIndices &indices :
291+
expandShapeOp.getReassociationIndices()) {
292+
// collapsedSize will hold the size of the single dim that represents the
293+
// reassociation group in the non expanded tensor.
294+
OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
295+
// The basis and delinOffsets are used to create an affine.linearize_index
296+
// op to linearize the single offset value required for this reassociation
297+
// group.
298+
// basis holds the full sizes of the reassociation group dimensions
299+
// of the expanded tensor.
300+
// delinOffsets as in "delinearized offsets", holds the offsets within the
301+
// reassociation group dimensions of the expanded tensor.
302+
SmallVector<OpFoldResult> basis, delinOffsets;
303+
304+
for (long expandedDim : indices) {
305+
// basis and delinOffsets can be obtained directly from the expanded
306+
// state, but the collapsed size requires calculation as it did not
307+
// previously exist.
308+
basis.push_back(expandedShape[expandedDim]);
309+
delinOffsets.push_back(expandedOffsets[expandedDim]);
310+
collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
311+
}
312+
313+
SmallVector<Value> offsetVals =
314+
llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
315+
return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
316+
});
317+
OpFoldResult collapsedOffset =
318+
rewriter
319+
.create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
320+
/*disjoint=*/true)
321+
.getResult();
322+
collapsedOffsets.push_back(collapsedOffset);
323+
collapsedSizes.push_back(collapsedSize);
324+
325+
// Only unit stride supported.
326+
collapsedStrides.push_back(rewriter.getIndexAttr(1));
327+
}
328+
329+
// The shape of the result can be obtained from the sizes passed in.
330+
SmallVector<Value> dynDims;
331+
SmallVector<int64_t> shape;
332+
dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
333+
RankedTensorType resultType = RankedTensorType::get(
334+
shape, expandShapeOp.getResultType().getElementType());
335+
336+
// Create a new ExtractSliceOp and ExpandShapeOp.
337+
Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
338+
loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
339+
collapsedStrides);
340+
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
341+
sliceOp, resultType, newSliceOp,
342+
expandShapeOp.getReassociationIndices(), expandedSizes);
343+
return success();
344+
}
345+
346+
// Helper function to check if all the required conditions for the
347+
// tensor.extract_slice to be bubbled up through the tensor.expand_shape are
348+
// met.
349+
LogicalResult
350+
checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
351+
tensor::ExpandShapeOp expandShapeOp,
352+
PatternRewriter &rewriter) const {
353+
354+
if (!expandShapeOp) {
355+
return rewriter.notifyMatchFailure(
356+
sliceOp, "tensor.extract_slice source not produced by expand_shape");
357+
}
358+
359+
if (!sliceOp.hasUnitStride()) {
360+
return rewriter.notifyMatchFailure(
361+
sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
362+
"be supported in this transformation.");
363+
}
364+
365+
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
366+
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
367+
368+
if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
369+
sizes.size()) {
370+
return rewriter.notifyMatchFailure(sliceOp,
371+
"unimplemented: rank reducing slice");
372+
}
373+
374+
SmallVector<OpFoldResult> outputShape =
375+
getMixedValues(expandShapeOp.getStaticOutputShape(),
376+
expandShapeOp.getOutputShape(), rewriter);
377+
378+
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
379+
isZeroOffsetAndFullSize =
380+
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
381+
if (!isConstantIntValue(offset, 0))
382+
return false;
383+
FailureOr<bool> maybeEqual =
384+
ValueBoundsConstraintSet::areEqual(sliceSize, size);
385+
return llvm::succeeded(maybeEqual) && maybeEqual.value();
386+
};
387+
388+
// Check that the slice is contiguous within each reassociation group.
389+
// The slice is contiguous only if after the first dimension where a non
390+
// unit slice is taken, the slice size on all subsequent dimensions of the
391+
// group is equal to the entire size of the dimension.
392+
// Examples of contiguous slices:
393+
// full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
394+
// full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
395+
// Examples of non contiguous slices:
396+
// full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
397+
// full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
398+
for (const ReassociationIndices &indices :
399+
expandShapeOp.getReassociationIndices()) {
400+
int64_t i = 0;
401+
int64_t e = indices.size();
402+
// Find the first expanded dim after the first dim with non-unit extracted
403+
// size.
404+
for (; i < e; ++i) {
405+
if (!isConstantIntValue(sizes[indices[i]], 1)) {
406+
// +1 to skip the first non-unit size dim.
407+
i++;
408+
break;
409+
}
410+
}
411+
412+
// Verify that all subsequent dimensions extract the full size of the
413+
// source tensor.
414+
for (; i < e; ++i) {
415+
int64_t expandedDim = indices[i];
416+
if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
417+
outputShape[expandedDim])) {
418+
return rewriter.notifyMatchFailure(
419+
sliceOp, "Not a contiguous slice of the expanded tensor.");
420+
}
421+
}
422+
}
423+
424+
return success();
425+
}
426+
};
427+
213428
} // namespace
214429

215430
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +442,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
227442
RewritePatternSet &patterns) {
228443
patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
229444
}
445+
446+
void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
447+
RewritePatternSet &patterns) {
448+
patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
449+
}

0 commit comments

Comments
 (0)