Commit 6f1347d
[MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape (#131982)
Add a pattern that bubbles up tensor.extract_slice through tensor.collapse_shape. The pattern is registered in the pattern population function used by the transform op transform.apply_patterns.tensor.bubble_up_extract_slice and, as a cleanup pattern, by the transform op transform.structured.fuse. When added as a cleanup pattern of the tile-and-fuse utility, this pattern enables tiling and fusing op chains that contain tensor.collapse_shape; without it, that would not be possible, as tensor.collapse_shape does not implement the tiling interface. This is an additional pattern to the one added in PR #126898.
1 parent c87dc2b commit 6f1347d
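To illustrate the rewrite, here is a minimal sketch (shapes invented for this summary, not taken from the patch) of what the new pattern does: the slice is moved from the collapsed result onto the expanded source, and the collapse_shape is reapplied to the smaller tile.

```mlir
// BEFORE: the slice is taken from the collapsed result, so a tiling driver
// cannot look through the collapse_shape (it has no tiling interface).
%collapse = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<2x16x10xf32> into tensor<32x10xf32>
%slice = tensor.extract_slice %collapse[0, 0] [16, 10] [1, 1]
    : tensor<32x10xf32> to tensor<16x10xf32>

// AFTER: the slice is taken directly from the expanded source, and the
// collapse_shape now applies to the (smaller) extracted tile.
%slice = tensor.extract_slice %src[0, 0, 0] [1, 16, 10] [1, 1, 1]
    : tensor<2x16x10xf32> to tensor<1x16x10xf32>
%collapse = tensor.collapse_shape %slice [[0, 1], [2]]
    : tensor<1x16x10xf32> into tensor<16x10xf32>
```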

File tree: 3 files changed (+476 -1 lines changed)

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 253 additions & 1 deletion
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"

@@ -428,6 +429,256 @@ struct BubbleUpExpandShapeThroughExtractSlice
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
+/// `tensor.collapse_shape(tensor.extract_slice)`.
+///
+/// For this transformation to be possible, the slice must be representable,
+/// after bubbling up, as a single contiguous slice obtained via
+/// tensor.extract_slice within each reassociation group of the src.
+///
+/// If the extracted size and offset are static, this is possible when the
+/// following conditions are met within each reassociation group:
+/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
+/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be
+/// the shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous span of elements if and only if there exists an index k in
+/// {0, 1, ..., n} such that:
+///      S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+///      1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
+///                       one dimension),
+///      S_i = A_i for all i > k (that is, all trailing dimensions are
+///                               preserved in full).
+/// In other words, the slice shape S must be of the form:
+/// [ 1, 1, ..., 1, S_k, A_(k+1), A_(k+2), ..., A_n ]
+///
+/// If the extracted size and/or offset are dynamic, this is possible only if
+/// there is a single dimension in the reassociation group with a size not
+/// equal to 1.
+/// In other words, the tensor shape must be of the form:
+/// [ 1, 1, ..., 1, A, 1, ..., 1 ]
+/// Note: it might be possible to enable this pattern for more cases when the
+/// size/offset are dynamic by performing an analysis of the possible values
+/// that could be given to the size/offset.
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
+/// [20->10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
+///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
+/// ```
+///
+/// Negative example:
+/// The transformation is not possible because we cannot use a single slice
+/// to represent the reassociation group [2x3x10->???]:
+/// ```
+/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32>
+///     into tensor<60xf32>
+/// %extract = tensor.extract_slice %collapse[0][15][1] : tensor<60xf32> to
+///     tensor<15xf32>
+/// ```
+/// If we wanted the collapse to come after the extraction, a possible
+/// alternate transformation would be to extract multiple slices and concat
+/// them together:
+/// ```
+/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
+///     tensor<2x3x10xf32> to tensor<1x1x10xf32>
+/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
+///     tensor<2x3x10xf32> to tensor<1x1x5xf32>
+/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
+///     (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
+/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
+///     to tensor<15xf32>
+/// ```
+/// But that is not the intended purpose of this transformation.
+struct BubbleUpCollapseShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "tensor.extract_slice source not produced by tensor.collapse_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    // The tensor.extract_slice before applying the pattern works on the
+    // result of the tensor.collapse_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "collapsed", and ones referring to the state
+    // after applying the pattern are named with the prefix "expanded".
+    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        collapsedSizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
+    SmallVector<ReassociationIndices, 4> reassociationIndices =
+        collapseShapeOp.getReassociationIndices();
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank
+    // equal to the rank of the src of the collapse_shape. In each iteration
+    // of the loop, the offsets and sizes are computed per reassociation
+    // group.
+    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
+    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
+                                              rewriter.getIndexAttr(1));
+
+    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
+         llvm::zip_equal(collapsedSizes, collapsedOffsets,
+                         collapseShapeOp.getReassociationIndices())) {
+      // CASE #1 - size and/or offset are dynamic.
+      // In this case, the slice can be represented as a contiguous slice
+      // only if there is a single dimension in the reassociation group that
+      // has a size not equal to 1.
+      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+        int nonUnitSizeCount = 0;
+        for (int64_t expandedShapeIdx : reassocIndices) {
+          if (srcShape[expandedShapeIdx] != 1) {
+            nonUnitSizeCount++;
+            expandedSizes.push_back(collapsedSize);
+            expandedOffsets.push_back(collapsedOffset);
+            continue;
+          }
+
+          expandedSizes.push_back(rewriter.getIndexAttr(1));
+          expandedOffsets.push_back(rewriter.getIndexAttr(0));
+        }
+
+        if (nonUnitSizeCount != 1) {
+          return rewriter.notifyMatchFailure(
+              sliceOp,
+              "unsupported: slice cannot be verified to be contiguous");
+        }
+        continue;
+      }
+
+      // CASE #2 - size and offset are static.
+      // Verify that the slice can be represented as a contiguous slice of
+      // the src of the collapse_shape.
+      // Checking this is done in order of most internal dimensions first,
+      // so traversal is done in reverse order of the reassociation group.
+      // If the expected slice shape is [1, 1, ..., 1, S_k, A_(k+1),
+      // A_(k+2), ..., A_n], we first find the size and offset for n...k+1,
+      // then for k, and then for k-1...0.
+
+      // currentCollapsedSize and currentCollapsedOffset are initialized
+      // with the original collapsed size and offset and divided by the
+      // expanded shape size in each dimension as we go along the
+      // reassociation group. In essence we are spreading the original
+      // collapsed size and offset over the various expanded slice
+      // dimensions. The variables are used both to check the validity of
+      // the slice and to compute the expanded sizes and offsets.
+      int64_t currentCollapsedSize =
+          getConstantIntValue(collapsedSize).value();
+      int64_t currentCollapsedOffset =
+          getConstantIntValue(collapsedOffset).value();
+
+      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+
+      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+                                                  reassocIndices.rend());
+      int64_t idx = 0;
+      int64_t reassocGroupSize = reassocIndices.size();
+
+      // First handle the trailing dimensions where the slice size should be
+      // equal to the tensor shape and the offset should be 0 (n...k+1).
+      for (; idx < reassocGroupSize; ++idx) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+
+        if (currentCollapsedSize < expandedShapeSize)
+          break;
+
+        // We need to make sure that the slice size can be set to the shape
+        // size and the offset to 0.
+        if ((currentCollapsedSize % expandedShapeSize) != 0 ||
+            (currentCollapsedOffset % expandedShapeSize) != 0) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: cannot be extracted as a contiguous "
+                       "slice of the src of the collapse_shape");
+        }
+
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+
+        currentCollapsedSize /= expandedShapeSize;
+        currentCollapsedOffset /= expandedShapeSize;
+      }
+
+      // Now handle the first dim where slicing occurs on (k).
+      if (idx < reassocGroupSize) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+        // We need to make sure that the slice size in this dim + offset
+        // will not exceed the shape size.
+        if ((currentCollapsedSize + offsetInDim) >= expandedShapeSize) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: slice cannot be extracted as a "
+                       "contiguous slice of the src of the collapse_shape");
+        }
+
+        groupExpandedSizes.push_back(
+            rewriter.getIndexAttr(currentCollapsedSize));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
+
+        currentCollapsedOffset /= expandedShapeSize;
+      }
+
+      // Now handle the leading dimensions where the slice size is equal to
+      // 1 (k-1...0).
+      // The size for these dimensions must be 1 because of how we
+      // constructed the slice size of the expanded shape. We spread the
+      // original collapsed size over the expanded shape sizes until we
+      // reached dimension k, where the remaining size was smaller than the
+      // expanded shape size, and spread the remaining size on it. So, now
+      // we are left with only 1s.
+      for (idx++; idx < reassocGroupSize; ++idx) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
+        currentCollapsedOffset /= expandedShapeSize;
+      }
+
+      expandedSizes.append(groupExpandedSizes.rbegin(),
+                           groupExpandedSizes.rend());
+      expandedOffsets.append(groupExpandedOffsets.rbegin(),
+                             groupExpandedOffsets.rend());
+    }
+
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
+        expandedSizes, expandedStrides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -448,5 +699,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
 
 void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
+               BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
 }
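To make the static case (CASE #2 in the pattern above) concrete, here is a small worked example; the shapes and the offset/size values are invented for this note, not taken from the patch. The collapsed offset 66 and size 12 are spread over the reassociation group [4, 5, 6], innermost dimension first:

```mlir
// BEFORE: offset 66 and size 12 are taken from the collapsed tensor.
%collapse = tensor.collapse_shape %src [[0, 1, 2]]
    : tensor<4x5x6xf32> into tensor<120xf32>
%slice = tensor.extract_slice %collapse[66] [12] [1]
    : tensor<120xf32> to tensor<12xf32>

// Walking the group [4, 5, 6] innermost-first:
//   dim of size 6: 12 >= 6, 12 % 6 == 0 and 66 % 6 == 0
//     -> size 6, offset 0 (remaining size 2, remaining offset 11)
//   dim of size 5: remaining size 2 < 5, so this is dimension k:
//     offset 11 % 5 = 1, and 2 + 1 < 5
//     -> size 2, offset 1 (remaining offset 11 / 5 = 2)
//   dim of size 4: leading dimension -> size 1, offset 2 % 4 = 2
//
// AFTER: offsets [2, 1, 0], sizes [1, 2, 6]; the slice stays contiguous
// (elements 66..77 of the flattened tensor).
%slice = tensor.extract_slice %src[2, 1, 0] [1, 2, 6] [1, 1, 1]
    : tensor<4x5x6xf32> to tensor<1x2x6xf32>
%collapse = tensor.collapse_shape %slice [[0, 1, 2]]
    : tensor<1x2x6xf32> into tensor<12xf32>
```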

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 49 additions & 0 deletions
@@ -438,3 +438,52 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape(
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]]
+// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+  %collapse = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+  %empty = tensor.empty() : tensor<8x1800x32xf32>
+  %exp = linalg.exp ins(%collapse : tensor<8x1800x32xf32>) outs(%empty : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+  return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK: %[[ABS:.*]] = linalg.abs ins(%[[EXTRACT]]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]]
+// CHECK: %[[EXP:.*]] = linalg.exp ins(%[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+  %empty1 = tensor.empty() : tensor<1x8x1800x32xf32>
+  %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32>
+  %collapse = tensor.collapse_shape %abs [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+  %empty2 = tensor.empty() : tensor<8x1800x32xf32>
+  %exp = linalg.exp ins(%collapse : tensor<8x1800x32xf32>) outs(%empty2 : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+  return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield
+  }
+}
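The tests above exercise the static case through transform.structured.fuse. For the dynamic case (CASE #1 in the pattern), the reassociation group must contain a single non-unit dimension. A hypothetical example of the resulting rewrite, with shapes and SSA names invented for this note:

```mlir
// BEFORE: dynamic offset and size on the collapsed tensor. The group
// [1, 8, 1] has exactly one non-unit dimension, so the dynamic values can
// simply be forwarded to that dimension.
%collapse = tensor.collapse_shape %src [[0, 1, 2]]
    : tensor<1x8x1xf32> into tensor<8xf32>
%slice = tensor.extract_slice %collapse[%off] [%sz] [1]
    : tensor<8xf32> to tensor<?xf32>

// AFTER: the unit dimensions get size 1 and offset 0.
%slice = tensor.extract_slice %src[0, %off, 0] [1, %sz, 1] [1, 1, 1]
    : tensor<1x8x1xf32> to tensor<1x?x1xf32>
%collapse = tensor.collapse_shape %slice [[0, 1, 2]]
    : tensor<1x?x1xf32> into tensor<?xf32>
```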

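Besides serving as a cleanup pattern for transform.structured.fuse, the population function is also exposed (per the commit message) through transform.apply_patterns.tensor.bubble_up_extract_slice. A minimal sketch of applying the patterns standalone, following the same transform-sequence conventions as the tests above:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Apply the bubble-up patterns to every function in the payload.
    %func = transform.structured.match ops{["func.func"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.tensor.bubble_up_extract_slice
    } : !transform.any_op
    transform.yield
  }
}
```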