@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
1720
1720
}
1721
1721
};
1722
1722
1723
+ // / A pattern to drop unit dims from vector.transpose.
1724
+ // /
1725
+ // / Example:
1726
+ // /
1727
+ // / BEFORE:
1728
+ // / ```mlir
1729
+ // / %transpose = vector.transpose %vector, [3, 0, 1, 2]
1730
+ // / : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1731
+ // / ```
1732
+ // /
1733
+ // / AFTER:
1734
+ // / ```mlir
1735
+ // / %dropDims = vector.shape_cast %vector
1736
+ // / : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1737
+ // / %transpose = vector.transpose %0, [1, 0]
1738
+ // / : vector<4x[4]xf32> to vector<[4]x4xf32>
1739
+ // / %restoreDims = vector.shape_cast %transpose
1740
+ // / : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1741
+ // / ```
1742
+ struct DropUnitDimsFromTransposeOp final
1743
+ : OpRewritePattern<vector::TransposeOp> {
1744
+ using OpRewritePattern::OpRewritePattern;
1745
+
1746
+ LogicalResult matchAndRewrite (vector::TransposeOp op,
1747
+ PatternRewriter &rewriter) const override {
1748
+ VectorType sourceType = op.getSourceVectorType ();
1749
+ VectorType sourceTypeWithoutUnitDims =
1750
+ dropNonScalableUnitDimFromType (sourceType);
1751
+
1752
+ if (sourceType == sourceTypeWithoutUnitDims)
1753
+ return failure ();
1754
+
1755
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
1756
+ auto sourceDims = llvm::to_vector (vector::getDims (sourceType));
1757
+ SmallVector<int64_t > droppedDimsBefore (sourceType.getRank ());
1758
+ int64_t droppedDims = 0 ;
1759
+ for (auto [i, dim] : llvm::enumerate (sourceDims)) {
1760
+ droppedDimsBefore[i] = droppedDims;
1761
+ if (dim == std::make_tuple (1 , false ))
1762
+ ++droppedDims;
1763
+ }
1764
+
1765
+ // Drop unit dims from transpose permutation.
1766
+ ArrayRef<int64_t > perm = op.getPermutation ();
1767
+ SmallVector<int64_t > newPerm;
1768
+ for (int64_t idx : perm) {
1769
+ if (sourceDims[idx] == std::make_tuple (1 , false ))
1770
+ continue ;
1771
+ newPerm.push_back (idx - droppedDimsBefore[idx]);
1772
+ }
1773
+
1774
+ auto loc = op.getLoc ();
1775
+ // Drop the unit dims via shape_cast.
1776
+ auto dropDimsShapeCast = rewriter.create <vector::ShapeCastOp>(
1777
+ loc, sourceTypeWithoutUnitDims, op.getVector ());
1778
+ // Create the new transpose.
1779
+ auto tranposeWithoutUnitDims =
1780
+ rewriter.create <vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
1781
+ // Restore the unit dims via shape cast.
1782
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
1783
+ op, op.getResultVectorType (), tranposeWithoutUnitDims);
1784
+
1785
+ return failure ();
1786
+ }
1787
+ };
1788
+
1723
1789
// / Pattern to eliminate redundant zero-constants added to reduction operands.
1724
1790
// / It's enough for there to be one initial zero value, so we can eliminate the
1725
1791
// / extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
1924
1990
1925
1991
void mlir::vector::populateDropUnitDimWithShapeCastPatterns (
1926
1992
RewritePatternSet &patterns, PatternBenefit benefit) {
1927
- patterns.add <DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1928
- patterns.getContext (), benefit);
1993
+ patterns.add <DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
1994
+ ShapeCastOpFolder>( patterns.getContext (), benefit);
1929
1995
}
1930
1996
1931
1997
void mlir::vector::populateBubbleVectorBitCastOpPatterns (
0 commit comments