Skip to content

Commit 457afd9

Browse files
committed
catch additional foldable case
1 parent 3715de9 commit 457afd9

File tree

3 files changed

+49
-38
lines changed

3 files changed

+49
-38
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5575,13 +5575,11 @@ LogicalResult ShapeCastOp::verify() {
55755575
return success();
55765576
}
55775577

5578-
namespace {
5579-
55805578
/// Return true if `transpose` does not permute a pair of non-unit dims.
55815579
/// By `order preserving` we mean that the flattened versions of the input and
55825580
/// output vectors are (numerically) identical. In other words `transpose` is
55835581
/// effectively a shape cast.
5584-
bool isOrderPreserving(TransposeOp transpose) {
5582+
static bool isOrderPreserving(TransposeOp transpose) {
55855583
ArrayRef<int64_t> permutation = transpose.getPermutation();
55865584
VectorType sourceType = transpose.getSourceVectorType();
55875585
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5601,8 +5599,6 @@ bool isOrderPreserving(TransposeOp transpose) {
56015599
return true;
56025600
}
56035601

5604-
} // namespace
5605-
56065602
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56075603

56085604
VectorType resultType = getType();
@@ -5999,18 +5995,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
59995995
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
60005996
return ub::PoisonAttr::get(getContext());
60015997

6002-
// Eliminate identity transpose ops. This happens when the dimensions of the
6003-
// input vector remain in their original order after the transpose operation.
6004-
ArrayRef<int64_t> perm = getPermutation();
6005-
6006-
// Check if the permutation of the dimensions contains sequential values:
6007-
// {0, 1, 2, ...}.
6008-
for (int64_t i = 0, e = perm.size(); i < e; i++) {
6009-
if (perm[i] != i)
6010-
return {};
6011-
}
5998+
// Eliminate identity transposes, and more generally any transposes that
5999+
// preserves the shape without permuting elements.
6000+
//
6001+
// Examples of what to fold:
6002+
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6003+
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6004+
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6005+
//
6006+
// Example of what NOT to fold:
6007+
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6008+
//
6009+
if (getSourceVectorType() == getResultVectorType() &&
6010+
isOrderPreserving(*this))
6011+
return getVector();
60126012

6013-
return getVector();
6013+
return {};
60146014
}
60156015

60166016
LogicalResult vector::TransposeOp::verify() {

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
450450

451451
// -----
452452

453-
// CHECK-LABEL: transpose_1D_identity
454-
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
455-
func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
456-
// CHECK-NOT: transpose
457-
%0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
458-
// CHECK-NEXT: return [[ARG]]
459-
return %0 : vector<4xf32>
460-
}
461-
462-
// -----
463-
464-
// CHECK-LABEL: transpose_2D_identity
465-
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
466-
func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
467-
// CHECK-NOT: transpose
468-
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
469-
// CHECK-NEXT: return [[ARG]]
470-
return %0 : vector<4x3xf32>
471-
}
472-
473-
// -----
474-
475453
// CHECK-LABEL: transpose_3D_identity
476454
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
477455
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,36 @@ func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8>
247247
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
248248
return %1 : vector<2x3xi8>
249249
}
250+
251+
// -----
252+
253+
// Test of transpose folding
254+
// CHECK-LABEL: transpose_1D_identity
255+
// CHECK-SAME: [[ARG:%.*]]: vector<4xf32>
256+
// CHECK-NEXT: return [[ARG]]
257+
func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
258+
%0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
259+
return %0 : vector<4xf32>
260+
}
261+
262+
// -----
263+
264+
// Test of transpose folding
265+
// CHECK-LABEL: transpose_2D_identity
266+
// CHECK-SAME: [[ARG:%.*]]: vector<4x3xf32>
267+
// CHECK-NEXT: return [[ARG]]
268+
func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
269+
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
270+
return %0 : vector<4x3xf32>
271+
}
272+
273+
// -----
274+
275+
// Test of transpose folding
276+
// CHECK-LABEL: transpose_shape_and_order_preserving
277+
// CHECK-SAME: [[ARG:%.*]]: vector<6x1x1x4xi8>
278+
// CHECK-NEXT: return [[ARG]]
279+
func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> {
280+
%0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
281+
return %0 : vector<6x1x1x4xi8>
282+
}

0 commit comments

Comments
 (0)