Skip to content

Commit 25a359b

Browse files
author
Ofri Frishman
committed
[MLIR] Add pattern to bubble up tensor.extract_slice
Add a pattern that bubbles up tensor.extract_slice through tensor.expand_shape, and add a transform op to the tensor dialect to directly use this pattern. This pattern enables tiling and fusing op chains which contain tensor.expand_shape if added as a cleanup pattern of the tile and fuse utility. Without this pattern that would not be possible, as tensor.expand_shape does not implement the tiling interface. In addition, this pattern is registered as a cleanup pattern for transform.structured.fuse. The pattern was first implemented in the IREE project by Quinn Dawkins and is being upstreamed.
1 parent 73413bd commit 25a359b

File tree

7 files changed

+478
-0
lines changed

7 files changed

+478
-0
lines changed

mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,16 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
111111
let assemblyFormat = "attr-dict";
112112
}
113113

114+
def ApplyBubbleUpExtractSlicePatternsOp : Op<Transform_Dialect,
    "apply_patterns.tensor.bubble_up_extract_slice",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let description = [{
    Indicates that tensor.extract_slice and its producer should swap location.
  }];

  let assemblyFormat = "attr-dict";
}
123+
114124
def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
115125
"apply_patterns.tensor.rewrite_as_constant",
116126
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
5858
void populateMergeConsecutiveInsertExtractSlicePatterns(
5959
RewritePatternSet &patterns);
6060

61+
/// Appends patterns that are used to bubble up tensor.extract slice op above
62+
/// its producer. When used as cleanup patterns of tile and fuse, enables fusing
63+
/// the producer with the consumer even if the producer does not implement the
64+
/// tiling interface.
65+
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
66+
6167
/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
6268
/// rank expansions.
6369
void populateDropRedundantInsertSliceRankExpansionPatterns(

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
582582
RewritePatternSet patterns(context);
583583
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
584584
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
585+
tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
585586
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
586587
}
587588

mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
125125
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
126126
}
127127

128+
// Hooks the tensor.extract_slice bubble-up patterns into the transform op so
// that `apply_patterns.tensor.bubble_up_extract_slice` populates them.
void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
}
132+
128133
void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129134
RewritePatternSet &patterns) {
130135
ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Arith/Utils/Utils.h"
911
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1012
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1113
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1215
#include "llvm/Support/Debug.h"
16+
#include "llvm/Support/LogicalResult.h"
1317

1418
using namespace mlir;
1519
using namespace mlir::tensor;
@@ -210,6 +214,200 @@ struct BubbleUpExpandThroughParallelCollapse
210214
}
211215
};
212216

217+
/// Converts `tensor.extract_slice(tensor.expand_shape)` to
218+
/// `tensor.expand_shape(tensor.extract_slice)`.
219+
/// For this transformation to be possible, the slice must be fully contiguous
220+
/// within each reassociation group of the expand_shape. If the transformation
221+
/// is not possible, or if the slice is rank reducing, the function returns
222+
/// failure.
223+
///
224+
/// Example:
225+
/// ```
226+
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
227+
/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
228+
/// %slice = tensor.extract_slice %reshape ...
229+
/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
230+
///
231+
/// // The transformation is possible because each reassociation group has a
232+
/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
233+
/// // After the transformation:
234+
///
235+
/// %slice = tensor.extract_slice %in ...
236+
/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
237+
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
238+
/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
239+
/// ```
240+
///
241+
/// Note - this pattern could be reworked to be a swap pattern between
242+
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
243+
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
244+
struct BubbleUpExpandShapeThroughExtractSlice
245+
: public OpRewritePattern<tensor::ExtractSliceOp> {
246+
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
247+
248+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
249+
PatternRewriter &rewriter) const override {
250+
auto expandShapeOp =
251+
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
252+
253+
if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
254+
rewriter)
255+
.failed())
256+
return failure();
257+
258+
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
259+
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
260+
SmallVector<OpFoldResult> outputShape =
261+
getMixedValues(expandShapeOp.getStaticOutputShape(),
262+
expandShapeOp.getOutputShape(), rewriter);
263+
264+
// Helper variables and function for accumulating the new offset and length
265+
// values.
266+
Location loc = expandShapeOp->getLoc();
267+
AffineExpr d0, d1, d2;
268+
bindDims(rewriter.getContext(), d0, d1, d2);
269+
// Multiply two integers.
270+
auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
271+
auto mulMap = AffineMap::get(2, 0, {d0 * d1});
272+
return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
273+
{v1, v2});
274+
};
275+
276+
// Compute new offsets, lengths, and strides.
277+
SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
278+
for (const ReassociationIndices &indices :
279+
expandShapeOp.getReassociationIndices()) {
280+
OpFoldResult newSize = rewriter.getIndexAttr(1);
281+
SmallVector<OpFoldResult> basis, delinOffsets;
282+
283+
int64_t i = 0;
284+
int64_t e = indices.size();
285+
// Offset = cumulative product of leading unit extracted dims.
286+
for (; i < e; ++i) {
287+
int64_t expandedDim = indices[i];
288+
if (!isConstantIntValue(sizes[expandedDim], 1))
289+
break;
290+
291+
basis.push_back(outputShape[expandedDim]);
292+
delinOffsets.push_back(offsets[expandedDim]);
293+
}
294+
295+
if (i != e) {
296+
int64_t expandedDim = indices[i];
297+
basis.push_back(outputShape[expandedDim]);
298+
delinOffsets.push_back(offsets[expandedDim]);
299+
newSize = sizes[expandedDim];
300+
i++;
301+
}
302+
303+
for (; i < e; ++i) {
304+
OpFoldResult fullSize = outputShape[indices[i]];
305+
basis.push_back(fullSize);
306+
delinOffsets.push_back(rewriter.getIndexAttr(0));
307+
newSize = mul(newSize, fullSize);
308+
}
309+
SmallVector<Value> offsetVals =
310+
llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
311+
return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
312+
});
313+
OpFoldResult newOffset =
314+
rewriter
315+
.create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
316+
/*disjoint=*/true)
317+
.getResult();
318+
newOffsets.push_back(newOffset);
319+
newLengths.push_back(newSize);
320+
321+
// Only unit stride supported.
322+
newStrides.push_back(rewriter.getIndexAttr(1));
323+
}
324+
325+
// The shape of the result can be obtained from the sizes passed in.
326+
SmallVector<Value> dynDims;
327+
SmallVector<int64_t> shape;
328+
dispatchIndexOpFoldResults(sizes, dynDims, shape);
329+
RankedTensorType resultType = RankedTensorType::get(
330+
shape, expandShapeOp.getResultType().getElementType());
331+
332+
// Create a new ExtractSliceOp and ExpandShapeOp.
333+
Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
334+
loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
335+
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
336+
sliceOp, resultType, newSliceOp,
337+
expandShapeOp.getReassociationIndices(), sizes);
338+
return success();
339+
}
340+
341+
LogicalResult
342+
checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
343+
tensor::ExpandShapeOp expandShapeOp,
344+
PatternRewriter &rewriter) const {
345+
346+
if (!expandShapeOp) {
347+
return rewriter.notifyMatchFailure(
348+
sliceOp, "tensor.extract_slice source not produced by expand_shape");
349+
}
350+
351+
if (!sliceOp.hasUnitStride()) {
352+
return rewriter.notifyMatchFailure(
353+
sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
354+
"be supported in this transformation.");
355+
}
356+
357+
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
358+
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
359+
360+
if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
361+
sizes.size()) {
362+
return rewriter.notifyMatchFailure(sliceOp,
363+
"unimplemented: rank reducing slice");
364+
}
365+
366+
SmallVector<OpFoldResult> outputShape =
367+
getMixedValues(expandShapeOp.getStaticOutputShape(),
368+
expandShapeOp.getOutputShape(), rewriter);
369+
370+
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
371+
isZeroOffsetAndFullSize =
372+
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
373+
if (!isConstantIntValue(offset, 0))
374+
return false;
375+
FailureOr<bool> maybeEqual =
376+
ValueBoundsConstraintSet::areEqual(sliceSize, size);
377+
return llvm::succeeded(maybeEqual) && maybeEqual.value();
378+
};
379+
380+
// First verify that this is a full slice of the expanded tensor.
381+
for (const ReassociationIndices &indices :
382+
expandShapeOp.getReassociationIndices()) {
383+
int64_t i = 0;
384+
int64_t e = indices.size();
385+
// Find the first expanded dim after the first dim with non-unit extracted
386+
// size.
387+
for (; i < e; ++i) {
388+
if (!isConstantIntValue(sizes[indices[i]], 1)) {
389+
// +1 to skip the first non-unit size dim.
390+
i++;
391+
break;
392+
}
393+
}
394+
395+
// Verify that all subsequent dimensions extract the full size of the
396+
// source tensor.
397+
for (; i < e; ++i) {
398+
int64_t expandedDim = indices[i];
399+
if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
400+
outputShape[expandedDim])) {
401+
return rewriter.notifyMatchFailure(
402+
sliceOp, "Not a contiguous slice of the expanded tensor.");
403+
}
404+
}
405+
}
406+
407+
return success();
408+
}
409+
};
410+
213411
} // namespace
214412

215413
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +425,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
227425
RewritePatternSet &patterns) {
228426
patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
229427
}
428+
429+
// Registers the extract_slice bubble-up rewrites declared above into the
// given pattern set.
void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
}

0 commit comments

Comments
 (0)