Skip to content

Commit 2f05786

Browse files
committed
CR fixes
1 parent 5845db6 commit 2f05786

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -510,10 +510,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
510510
PatternRewriter &rewriter) const override {
511511
auto collapseShapeOp =
512512
sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
513-
if (!collapseShapeOp)
513+
if (!collapseShapeOp) {
514514
return rewriter.notifyMatchFailure(
515515
sliceOp,
516516
"tensor.extract_slice source not produced by tensor.collapse_shape");
517+
}
517518

518519
if (!sliceOp.hasUnitStride()) {
519520
return rewriter.notifyMatchFailure(
@@ -530,9 +531,10 @@ struct BubbleUpCollapseShapeThroughExtractSlice
530531
SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
531532

532533
if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
533-
collapsedSizes.size())
534+
collapsedSizes.size()) {
534535
return rewriter.notifyMatchFailure(sliceOp,
535536
"unimplemented: rank reducing slice");
537+
}
536538

537539
ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
538540
SmallVector<ReassociationIndices, 4> reassociationIndices =
@@ -546,10 +548,9 @@ struct BubbleUpCollapseShapeThroughExtractSlice
546548
SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
547549
rewriter.getIndexAttr(1));
548550

549-
for (auto [groupIdx, reassocIndices] :
550-
enumerate(collapseShapeOp.getReassociationIndices())) {
551-
OpFoldResult collapsedSize = collapsedSizes[groupIdx];
552-
OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
551+
for (auto [collapsedSize, collapsedOffset, reassocIndices] :
552+
llvm::zip_equal(collapsedSizes, collapsedOffsets,
553+
collapseShapeOp.getReassociationIndices())) {
553554
// CASE #1 - size and/or offset are dynamic.
554555
// In this case, the slice can be represented as a contiguous slice only
555556
// if there is a single dimension in the reassociation group that has a
@@ -614,10 +615,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
614615
// We need to make sure that the slice size can be set to the shape size
615616
// and the offset to 0.
616617
if ((currentCollapsedsize % expandedShapeSize) != 0 ||
617-
(currentCollapsedOffset % expandedShapeSize) != 0)
618+
(currentCollapsedOffset % expandedShapeSize) != 0) {
618619
return rewriter.notifyMatchFailure(
619620
sliceOp, "unsupported: cannot be extracted as a contiguous slice "
620621
"of the src of the collapse_shape");
622+
}
621623

622624
groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
623625
groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
@@ -632,10 +634,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice
632634
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
633635
// We need to make sure that the slice size in this dim + offset will
634636
// not exceed the shape size.
635-
if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
637+
if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
636638
return rewriter.notifyMatchFailure(
637639
sliceOp, "unsupported: slice cannot be extracted as a contiguous "
638640
"slice of the src of the collapse_shape");
641+
}
639642

640643
groupExpandedSizes.push_back(
641644
rewriter.getIndexAttr(currentCollapsedsize));

0 commit comments

Comments
 (0)