Skip to content

Commit d41ea3c

Browse files
committed
[mlir][vector] Add pattern to drop unit dims from vector.transpose
Example: BEFORE: ```mlir %transpose = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> ``` AFTER: ```mlir %dropDims = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> %transpose = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> %restoreDims = vector.shape_cast %transpose : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> ```
1 parent b1234dd commit d41ea3c

File tree

3 files changed

+106
-2
lines changed

3 files changed

+106
-2
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
120120
};
121121
}
122122

123+
/// Returns an iterator over the dims (inc scalability) of a VectorType.
124+
inline auto getDims(VectorType vType) {
125+
return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
126+
}
127+
123128
/// A wrapper for getMixedSizes for vector.transfer_read and
124129
/// vector.transfer_write Ops (for source and destination, respectively).
125130
///

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
17201720
}
17211721
};
17221722

1723+
/// A pattern to drop unit dims from vector.transpose.
1724+
///
1725+
/// Example:
1726+
///
1727+
/// BEFORE:
1728+
/// ```mlir
1729+
/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
1730+
/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1731+
/// ```
1732+
///
1733+
/// AFTER:
1734+
/// ```mlir
1735+
/// %dropDims = vector.shape_cast %vector
1736+
/// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1737+
/// %transpose = vector.transpose %0, [1, 0]
1738+
/// : vector<4x[4]xf32> to vector<[4]x4xf32>
1739+
/// %restoreDims = vector.shape_cast %transpose
1740+
/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1741+
/// ```
1742+
struct DropUnitDimsFromTransposeOp final
1743+
: OpRewritePattern<vector::TransposeOp> {
1744+
using OpRewritePattern::OpRewritePattern;
1745+
1746+
LogicalResult matchAndRewrite(vector::TransposeOp op,
1747+
PatternRewriter &rewriter) const override {
1748+
VectorType sourceType = op.getSourceVectorType();
1749+
VectorType sourceTypeWithoutUnitDims =
1750+
dropNonScalableUnitDimFromType(sourceType);
1751+
1752+
if (sourceType == sourceTypeWithoutUnitDims)
1753+
return failure();
1754+
1755+
// Construct a map from dimIdx -> number of dims dropped before dimIdx.
1756+
auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
1757+
SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
1758+
int64_t droppedDims = 0;
1759+
for (auto [i, dim] : llvm::enumerate(sourceDims)) {
1760+
droppedDimsBefore[i] = droppedDims;
1761+
if (dim == std::make_tuple(1, false))
1762+
++droppedDims;
1763+
}
1764+
1765+
// Drop unit dims from transpose permutation.
1766+
ArrayRef<int64_t> perm = op.getPermutation();
1767+
SmallVector<int64_t> newPerm;
1768+
for (int64_t idx : perm) {
1769+
if (sourceDims[idx] == std::make_tuple(1, false))
1770+
continue;
1771+
newPerm.push_back(idx - droppedDimsBefore[idx]);
1772+
}
1773+
1774+
auto loc = op.getLoc();
1775+
// Drop the unit dims via shape_cast.
1776+
auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
1777+
loc, sourceTypeWithoutUnitDims, op.getVector());
1778+
// Create the new transpose.
1779+
auto tranposeWithoutUnitDims =
1780+
rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
1781+
// Restore the unit dims via shape cast.
1782+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
1783+
op, op.getResultVectorType(), tranposeWithoutUnitDims);
1784+
1785+
return failure();
1786+
}
1787+
};
1788+
17231789
/// Pattern to eliminate redundant zero-constants added to reduction operands.
17241790
/// It's enough for there to be one initial zero value, so we can eliminate the
17251791
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
19241990

19251991
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
19261992
RewritePatternSet &patterns, PatternBenefit benefit) {
1927-
patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1928-
patterns.getContext(), benefit);
1993+
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
1994+
ShapeCastOpFolder>(patterns.getContext(), benefit);
19291995
}
19301996

19311997
void mlir::vector::populateBubbleVectorBitCastOpPatterns(

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,3 +700,36 @@ func.func @negative_out_of_bound_transfer_write(
700700
}
701701
// CHECK: func.func @negative_out_of_bound_transfer_write
702702
// CHECK-NOT: memref.collapse_shape
703+
704+
// -----
705+
706+
///----------------------------------------------------------------------------------------
707+
/// [Pattern: DropUnitDimsFromTransposeOp]
708+
/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
709+
///----------------------------------------------------------------------------------------
710+
711+
func.func @transpose_with_internal_unit_dims(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
712+
%0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
713+
return %0 : vector<[4]x1x1x4xf32>
714+
}
715+
716+
// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
717+
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
718+
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
719+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
720+
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
721+
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
722+
723+
// -----
724+
725+
func.func @transpose_with_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x1x1x1x4x1xf32> {
726+
%0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
727+
return %0 : vector<[4]x1x1x1x4x1xf32>
728+
}
729+
730+
// CHECK-LABEL: func.func @transpose_with_units_dims_before_and_after(
731+
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
732+
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
733+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
734+
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x1x4x1xf32>
735+
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x1x4x1xf32>

0 commit comments

Comments
 (0)