@@ -510,10 +510,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
510
510
PatternRewriter &rewriter) const override {
511
511
auto collapseShapeOp =
512
512
sliceOp.getSource ().getDefiningOp <tensor::CollapseShapeOp>();
513
- if (!collapseShapeOp)
513
+ if (!collapseShapeOp) {
514
514
return rewriter.notifyMatchFailure (
515
515
sliceOp,
516
516
" tensor.extract_slice source not produced by tensor.collapse_shape" );
517
+ }
517
518
518
519
if (!sliceOp.hasUnitStride ()) {
519
520
return rewriter.notifyMatchFailure (
@@ -530,9 +531,10 @@ struct BubbleUpCollapseShapeThroughExtractSlice
530
531
SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes ();
531
532
532
533
if (static_cast <size_t >(sliceOp.getResultType ().getRank ()) !=
533
- collapsedSizes.size ())
534
+ collapsedSizes.size ()) {
534
535
return rewriter.notifyMatchFailure (sliceOp,
535
536
" unimplemented: rank reducing slice" );
537
+ }
536
538
537
539
ArrayRef<int64_t > srcShape = collapseShapeOp.getSrcType ().getShape ();
538
540
SmallVector<ReassociationIndices, 4 > reassociationIndices =
@@ -546,10 +548,9 @@ struct BubbleUpCollapseShapeThroughExtractSlice
546
548
SmallVector<OpFoldResult> expandedStrides (srcShape.size (),
547
549
rewriter.getIndexAttr (1 ));
548
550
549
- for (auto [groupIdx, reassocIndices] :
550
- enumerate(collapseShapeOp.getReassociationIndices ())) {
551
- OpFoldResult collapsedSize = collapsedSizes[groupIdx];
552
- OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
551
+ for (auto [collapsedSize, collapsedOffset, reassocIndices] :
552
+ llvm::zip_equal (collapsedSizes, collapsedOffsets,
553
+ collapseShapeOp.getReassociationIndices ())) {
553
554
// CASE #1 - size and/or offset are dynamic.
554
555
// In this case, the slice can be represented as a contiguous slice only
555
556
// if there is a single dimension in the reassociation group that has a
@@ -614,10 +615,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
614
615
// We need to make sure that the slice size can be set to the shape size
615
616
// and the offset to 0.
616
617
if ((currentCollapsedsize % expandedShapeSize) != 0 ||
617
- (currentCollapsedOffset % expandedShapeSize) != 0 )
618
+ (currentCollapsedOffset % expandedShapeSize) != 0 ) {
618
619
return rewriter.notifyMatchFailure (
619
620
sliceOp, " unsupported: cannot be extracted as a contiguous slice "
620
621
" of the src of the collapse_shape" );
622
+ }
621
623
622
624
groupExpandedSizes.push_back (rewriter.getIndexAttr (expandedShapeSize));
623
625
groupExpandedOffsets.push_back (rewriter.getIndexAttr (0 ));
@@ -632,10 +634,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
632
634
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
633
635
// We need to make sure that the slice size in this dim + offset will
634
636
// not exceed the shape size.
635
- if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
637
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
636
638
return rewriter.notifyMatchFailure (
637
639
sliceOp, " unsupported: slice cannot be extracted as a contiguous "
638
640
" slice of the src of the collapse_shape" );
641
+ }
639
642
640
643
groupExpandedSizes.push_back (
641
644
rewriter.getIndexAttr (currentCollapsedsize));
0 commit comments