Skip to content

Commit 201da87

Browse files
authored
[mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. (#102518)
da8778e breaks the lowering of vector.transpose that all the dimensions are unit dimensions. The revision fixes the issue and adds a test. --------- Signed-off-by: hanhanW <[email protected]>
1 parent 4c1dbbe commit 201da87

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,6 +1771,13 @@ struct DropUnitDimsFromTransposeOp final
17711771
newPerm.push_back(idx - droppedDimsBefore[idx]);
17721772
}
17731773

1774+
// Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
1775+
// type when the dimensions are unit dimensions. In this case, the newPerm
1776+
// should be [0].
1777+
if (newPerm.empty()) {
1778+
newPerm.push_back(0);
1779+
}
1780+
17741781
Location loc = op.getLoc();
17751782
// Drop the unit dims via shape_cast.
17761783
auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
737737

738738
// -----
739739

740+
func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
741+
%res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
742+
return %res : vector<1x1x1xf32>
743+
}
744+
// The `vec` is returned because there are other flattening patterns that fold
745+
// vector.shape_cast ops away.
746+
// CHECK-LABEL: func.func @transpose_with_all_unit_dims
747+
// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
748+
// CHECK-NEXT: return %[[VEC]]
749+
750+
// -----
751+
740752
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
741753
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
742754
return %res : vector<4x3x2xf32>

0 commit comments

Comments
 (0)