12
12
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
13
13
#include " mlir/IR/PatternMatch.h"
14
14
#include " mlir/Interfaces/ValueBoundsOpInterface.h"
15
+ #include " llvm/ADT/STLExtras.h"
15
16
#include " llvm/Support/Debug.h"
16
17
#include " llvm/Support/LogicalResult.h"
17
18
@@ -428,6 +429,256 @@ struct BubbleUpExpandShapeThroughExtractSlice
428
429
}
429
430
};
430
431
432
+ // / Converts `tensor.extract_slice(tensor.collapse_shape)` to
433
+ // / `tensor.collapse_shape(tensor.extract_slice)`.
434
+ // /
435
+ // / For this transformation to be possible - after bubbling up, the extraction
436
+ // / of the contiguous slice must be representable as a single slice obtained via
437
+ // / tensor.extract_slice within each reassociation group of the src.
438
+ // /
439
+ // / In case the size and offset extracted are static then this is possible if
440
+ // / the following conditions are met within each reassociation group:
441
+ // / Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
442
+ // / dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
443
+ // / shape of a desired slice. A slice of shape S can be extracted as a
444
+ // / contiguous span of elements if and only if there exists an index k in {0, 1,
445
+ // / ..., n} such that:
446
+ // / S_i = 1 for all i < k (that is, all leading dimensions are singleton),
447
+ // / 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
448
+ // / one dimension),
449
+ // / S_i = A_i for all i > k (that is, all trailing dimensions are preserved
450
+ // / in full).
451
+ // / In other words, the slice shape S must be of the form:
452
+ // / [ 1, 1, ..., 1, S_k, A_{k+1}, A_{k+2}, ..., A_n ]
453
+ // /
454
+ // / In case the size and/or offset extracted are dynamic then this is possible
455
+ // / only if there is a single dimension in the reassociation group that has a size
456
+ // / not equal to 1.
457
+ // / In other words, the tensor shape must be of the form:
458
+ // / [ 1, 1, ..., 1, A, 1, ...,1 ]
459
+ // / Note - it might be possible to enable this pattern for more cases when the
460
+ // / size/offset are dynamic via performing an analysis of the possible values
461
+ // / that could be given to the size/offset.
462
+ // /
463
+ // / Example:
464
+ // / The transformation is possible because each reassociation group can be
465
+ // / represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
466
+ // / [20->10]).
467
+ // / ```
468
+ // / BEFORE:
469
+ // / %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
470
+ // / tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
471
+ // / %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
472
+ // / tensor<128x7x20xf32> to tensor<32x?x10xf32>
473
+ // /
474
+ // / AFTER:
475
+ // / %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
476
+ // / [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
477
+ // / %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
478
+ // / tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
479
+ // / ```
480
+ // /
481
+ // / Negative example:
482
+ // / The transformation is not possible because we cannot use a single slice to
483
+ // / represent the reassociation group [2x3x10->???]. If we would want the
484
+ // / collapse to be after the extraction, we would need to extract multiple
485
+ // / slices and concat them together.
486
+ // / ```
487
+ // / %collapse = tensor.collapse_shape %src [[0, 1, 2]] :
+ // /     tensor<2x3x10xf32> into tensor<60xf32>
+ // / %extract = tensor.extract_slice %collapse[0][15][1] :
+ // /     tensor<60xf32> to tensor<15xf32>
490
+ // / ```
491
+ // / If we would want the collapse to be after the extraction, a possible
492
+ // / alternate transformation could be to extract multiple slices and concat them
493
+ // / together:
494
+ // / ```
495
+ // / %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
496
+ // / tensor<2x3x10xf32> to tensor <1x1x10xf32>
497
+ // / %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
498
+ // / tensor<2x3x10xf32> to tensor <1x1x5xf32>
499
+ // / %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
500
+ // / (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
501
+ // / %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
502
+ // / to tensor<15xf32>
503
+ // / ```
504
+ // / But this is not the intended purpose of the transformation.
505
+ struct BubbleUpCollapseShapeThroughExtractSlice
506
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
507
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
508
+
509
+ LogicalResult matchAndRewrite (tensor::ExtractSliceOp sliceOp,
510
+ PatternRewriter &rewriter) const override {
511
+ auto collapseShapeOp =
512
+ sliceOp.getSource ().getDefiningOp <tensor::CollapseShapeOp>();
513
+ if (!collapseShapeOp) {
514
+ return rewriter.notifyMatchFailure (
515
+ sliceOp,
516
+ " tensor.extract_slice source not produced by tensor.collapse_shape" );
517
+ }
518
+
519
+ if (!sliceOp.hasUnitStride ()) {
520
+ return rewriter.notifyMatchFailure (
521
+ sliceOp, " unsupported: non-unit stride. Only contiguous slices can "
522
+ " be supported in this transformation." );
523
+ }
524
+
525
+ // The tensor.extract_slice before applying the pattern works on the result
526
+ // of the tensor.collapse_shape, so variables (i.e. inputs for
527
+ // ExtractSliceOp) referring to the state before applying the pattern are
528
+ // named with the prefix "collapsed", and ones referring to the state after
529
+ // applying the pattern are named with the prefix "expanded".
530
+ SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets ();
531
+ SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes ();
532
+
533
+ if (static_cast <size_t >(sliceOp.getResultType ().getRank ()) !=
534
+ collapsedSizes.size ()) {
535
+ return rewriter.notifyMatchFailure (sliceOp,
536
+ " unimplemented: rank reducing slice" );
537
+ }
538
+
539
+ ArrayRef<int64_t > srcShape = collapseShapeOp.getSrcType ().getShape ();
540
+ SmallVector<ReassociationIndices, 4 > reassociationIndices =
541
+ collapseShapeOp.getReassociationIndices ();
542
+
543
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
544
+ // The new tensor.extract_slice will work on a tensor that has has a rank
545
+ // equal to the rank of the src of the collapse_shape. In each iteration of
546
+ // the loop, the offsets and sizes will be computed per reassociation group.
547
+ SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
548
+ SmallVector<OpFoldResult> expandedStrides (srcShape.size (),
549
+ rewriter.getIndexAttr (1 ));
550
+
551
+ for (auto [collapsedSize, collapsedOffset, reassocIndices] :
552
+ llvm::zip_equal (collapsedSizes, collapsedOffsets,
553
+ collapseShapeOp.getReassociationIndices ())) {
554
+ // CASE #1 - size and/or offset are dynamic.
555
+ // In this case, the slice can be represented as a contiguous slice only
556
+ // if there is a single dimension in the reassociation group that has a
557
+ // size not equal to 1.
558
+ if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
559
+ int nonUnitSizeCount = 0 ;
560
+ for (int64_t expandedShapeIdx : reassocIndices) {
561
+ if (srcShape[expandedShapeIdx] != 1 ) {
562
+ nonUnitSizeCount++;
563
+ expandedSizes.push_back (collapsedSize);
564
+ expandedOffsets.push_back (collapsedOffset);
565
+ continue ;
566
+ }
567
+
568
+ expandedSizes.push_back (rewriter.getIndexAttr (1 ));
569
+ expandedOffsets.push_back (rewriter.getIndexAttr (0 ));
570
+ }
571
+
572
+ if (nonUnitSizeCount != 1 ) {
573
+ return rewriter.notifyMatchFailure (
574
+ sliceOp,
575
+ " unsupported: slice cannot be verified to be contiguous" );
576
+ }
577
+ continue ;
578
+ }
579
+
580
+ // CASE #2 = size and offset are static.
581
+ // Verify that the slice can be represented as a contiguous slice of the
582
+ // src of the collapse_shape.
583
+ // Checking this is done on order of most internal dimensions first,
584
+ // so traversal is done in reverse order of the reassociation group.
585
+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
586
+ // ...,An] then we first find the size and offset for n...k+1 then for k
587
+ // and then for k-1...0.
588
+
589
+ // currentCollapsedsize and currentCollapsedOffset are initialized with
590
+ // the original collapsed size and offset and divided by the expanded
591
+ // shape size in each dimension as we go along the reassociation group.
592
+ // In essence we are spreading the original collapsed size and offset over
593
+ // the various expanded slice dimensions.
594
+ // The variables are used both to check the validity of the slice and to
595
+ // compute the expanded sizes and offsets.
596
+ int64_t currentCollapsedsize = getConstantIntValue (collapsedSize).value ();
597
+ int64_t currentCollapsedOffset =
598
+ getConstantIntValue (collapsedOffset).value ();
599
+
600
+ SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
601
+
602
+ ReassociationIndices reversedReassocIndices (reassocIndices.rbegin (),
603
+ reassocIndices.rend ());
604
+ int64_t idx = 0 ;
605
+ int64_t reassocGroupSize = reassocIndices.size ();
606
+
607
+ // First handle the trailing dimensions where the slice size should be
608
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
609
+ for (; idx < reassocGroupSize; ++idx) {
610
+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
611
+
612
+ if (currentCollapsedsize < expandedShapeSize)
613
+ break ;
614
+
615
+ // We need to make sure that the slice size can be set to the shape size
616
+ // and the offset to 0.
617
+ if ((currentCollapsedsize % expandedShapeSize) != 0 ||
618
+ (currentCollapsedOffset % expandedShapeSize) != 0 ) {
619
+ return rewriter.notifyMatchFailure (
620
+ sliceOp, " unsupported: cannot be extracted as a contiguous slice "
621
+ " of the src of the collapse_shape" );
622
+ }
623
+
624
+ groupExpandedSizes.push_back (rewriter.getIndexAttr (expandedShapeSize));
625
+ groupExpandedOffsets.push_back (rewriter.getIndexAttr (0 ));
626
+
627
+ currentCollapsedsize /= expandedShapeSize;
628
+ currentCollapsedOffset /= expandedShapeSize;
629
+ }
630
+
631
+ // Now handle the first dim where slicing occurs on (k).
632
+ if (idx < reassocGroupSize) {
633
+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
634
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
635
+ // We need to make sure that the slice size in this dim + offset will
636
+ // not exceed the shape size.
637
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
638
+ return rewriter.notifyMatchFailure (
639
+ sliceOp, " unsupported: slice cannot be extracted as a contiguous "
640
+ " slice of the src of the collapse_shape" );
641
+ }
642
+
643
+ groupExpandedSizes.push_back (
644
+ rewriter.getIndexAttr (currentCollapsedsize));
645
+ groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
646
+
647
+ currentCollapsedOffset /= expandedShapeSize;
648
+ }
649
+
650
+ // Now handle the leading dimensions where the slice size is equal to 1
651
+ // (k-1...0).
652
+ // The size for these dimensions must be 1 because of how we constructed
653
+ // the slice size of the expanded shape. We spread the original collapsed
654
+ // size over the expanded shape sizes until we reached dimension k where
655
+ // the remaining size was smaller than the expanded shape size, and spread
656
+ // the remaining size on it. So, now we are left with only 1s.
657
+ for (idx++; idx < reassocGroupSize; ++idx) {
658
+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
659
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
660
+ groupExpandedSizes.push_back (rewriter.getIndexAttr (1 ));
661
+ groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
662
+ currentCollapsedOffset /= expandedShapeSize;
663
+ }
664
+
665
+ expandedSizes.append (groupExpandedSizes.rbegin (),
666
+ groupExpandedSizes.rend ());
667
+ expandedOffsets.append (groupExpandedOffsets.rbegin (),
668
+ groupExpandedOffsets.rend ());
669
+ }
670
+
671
+ Value newSliceOp = rewriter.create <tensor::ExtractSliceOp>(
672
+ collapseShapeOp->getLoc (), collapseShapeOp.getSrc (), expandedOffsets,
673
+ expandedSizes, expandedStrides);
674
+ rewriter.replaceOpWithNewOp <tensor::CollapseShapeOp>(
675
+ sliceOp, sliceOp.getResultType (), newSliceOp,
676
+ collapseShapeOp.getReassociationIndices ());
677
+
678
+ return success ();
679
+ }
680
+ };
681
+
431
682
} // namespace
432
683
433
684
void mlir::tensor::populateReassociativeReshapeFoldingPatterns (
@@ -448,5 +699,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
448
699
449
700
void mlir::tensor::populateBubbleUpExtractSliceOpPatterns (
450
701
RewritePatternSet &patterns) {
451
- patterns.add <BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext ());
702
+ patterns.add <BubbleUpExpandShapeThroughExtractSlice,
703
+ BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext ());
452
704
}
0 commit comments