Skip to content

Commit b9a071d

Browse files
authored
[mlir][Linalg] Add folders for linalg.transpose (#81709)
This PR adds folders for `linalg.transpose` ops that have either a single dimension (rank-1 operand) or an identity permutation. The folding removes the `linalg.transpose` op and propagates its input tensor directly to the users.
1 parent 8603a7b commit b9a071d

File tree

7 files changed

+69
-50
lines changed

7 files changed

+69
-50
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
245245
}
246246
```
247247

248-
Shortened print form is available. Applies to simple maps with one
248+
Shortened print form is available. Applies to simple maps with one
249249
non-yield operation inside the body.
250250

251251
The example above will be printed as:
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
458458
::mlir::OperationState & odsState);
459459
}];
460460

461+
let hasFolder = 1;
461462
let hasCustomAssemblyFormat = 1;
462463
let hasVerifier = 1;
463464
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,6 +1786,22 @@ void TransposeOp::getEffects(
17861786
getDpsInits());
17871787
}
17881788

1789+
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1790+
SmallVectorImpl<OpFoldResult> &result) {
1791+
// Single dimension transpose.
1792+
if (getPermutation().size() == 0) {
1793+
result.push_back(getInput());
1794+
return success();
1795+
}
1796+
// Identity permutation.
1797+
if (isIdentityPermutation(getPermutation())) {
1798+
result.push_back(getInput());
1799+
return success();
1800+
}
1801+
1802+
return failure();
1803+
}
1804+
17891805
//===----------------------------------------------------------------------===//
17901806
// BroadcastOp
17911807
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,3 +1029,38 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
10291029
%0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
10301030
return %0 : tensor<2x3xf32>
10311031
}
1032+
1033+
// -----
1034+
1035+
func.func @transpose_1d(%input: tensor<16xf32>,
1036+
%init: tensor<16xf32>) -> tensor<16xf32> {
1037+
%transpose = linalg.transpose
1038+
ins(%input:tensor<16xf32>)
1039+
outs(%init:tensor<16xf32>)
1040+
permutation = [0]
1041+
func.return %transpose : tensor<16xf32>
1042+
}
1043+
1044+
// CHECK-LABEL: func @transpose_1d(
1045+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
1046+
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
1047+
// CHECK-NOT: linalg.transpose
1048+
// CHECK: return %[[INPUT]] : tensor<16xf32>
1049+
1050+
// -----
1051+
1052+
func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
1053+
%init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
1054+
%transpose = linalg.transpose
1055+
ins(%input:tensor<16x32x64xf32>)
1056+
outs(%init:tensor<16x32x64xf32>)
1057+
permutation = [0, 1, 2]
1058+
func.return %transpose : tensor<16x32x64xf32>
1059+
}
1060+
1061+
// CHECK-LABEL: func @transpose_identity_perm(
1062+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
1063+
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
1064+
// CHECK-NOT: linalg.transpose
1065+
// CHECK: return %[[INPUT]] : tensor<16x32x64xf32>
1066+

mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,8 @@ func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %a
4848
// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]]
4949
// CHECK: tensor.yield %[[PAD_VAL]]
5050
// CHECK: } : tensor<?x?xf32> to tensor<8x2xf32>
51-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
52-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
53-
// CHECK-SAME: ins(%[[PAD]] : tensor<8x2xf32>)
54-
// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
55-
// CHECK-SAME: permutation = [0, 1]
56-
// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
51+
// CHECK-NOT: linalg.transpose
52+
// CHECK: %{{.+}} = tensor.insert_slice %[[PAD]] into %{{.+}}
5753

5854
module attributes {transform.with_named_sequence} {
5955
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -81,12 +77,8 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
8177
// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
8278
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]]
8379
// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [32, 8] [1, 1]
84-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
85-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
86-
// CHECK-SAME: ins(%[[TILE]]
87-
// CHECK-SAME: outs(%[[EMPTY]]
88-
// CHECK-SAME: permutation = [0, 1]
89-
// CHECK: %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
80+
// CHECK-NOT: linalg.transpose
81+
// CHECK: %[[SUB_ITER:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
9082
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32>
9183
// CHECK: %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}}
9284
// CHECK-SAME: [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32>

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
2929
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
3030
// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1]
3131
// CHECK: tensor.yield %[[PAD_VAL]]
32-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
33-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
34-
// CHECK-SAME: ins(%[[PAD]] : tensor<8x2xf32>)
35-
// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
36-
// CHECK-SAME: permutation = [0, 1]
37-
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
32+
// CHECK-NOT: linalg.transpose
33+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[DEST]]
3834
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
3935
// CHECK: return %[[INSERT]]
4036

@@ -47,12 +43,8 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
4743
// CHECK-LABEL: func.func @simple_NC_to_CNnc
4844
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
4945
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
50-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
51-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
52-
// CHECK-SAME: ins(%[[SRC]] : tensor<32x8xf32>)
53-
// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x8xf32>)
54-
// CHECK-SAME: permutation = [0, 1]
55-
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
46+
// CHECK-NOT: linalg.transpose
47+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DEST]]
5648
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
5749
// CHECK: return %[[INSERT]]
5850

mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,8 @@ func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13
5757
// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
5858
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
5959
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
60-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
61-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
62-
// CHECK-SAME: ins(%[[TILE]] : tensor<8x2xf32>)
63-
// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
64-
// CHECK-SAME: permutation = [0, 1]
65-
// CHECK: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
60+
// CHECK-NOT: linalg.transpose
61+
// CHECK: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TILE]]
6662
// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
6763
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
6864
// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
@@ -96,12 +92,8 @@ func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>)
9692
// CHECK-SAME: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
9793
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
9894
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
99-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
100-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
101-
// CHECK-SAME: ins(%[[TILE]]
102-
// CHECK-SAME: outs(%[[EMPTY]]
103-
// CHECK-SAME: permutation = [0, 1]
104-
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
95+
// CHECK-NOT: linalg.transpose
96+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
10597
// CHECK-SAME: [%[[K]], %[[C]]] [32, 8] [1, 1]
10698

10799

mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,10 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
2727
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
2828
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
2929
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
30-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
31-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
32-
// CHECK-SAME: ins(%[[TILE]] : tensor<8x2xf32>)
33-
// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
34-
// CHECK-SAME: permutation = [0, 1]
30+
// CHECK-NOT: linalg.transpose
3531
// They have the same type, so the insert_slice op is folded
3632
// away.
37-
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1]
33+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
3834
// CHECK: return %[[SLICE]]
3935

4036
// -----
@@ -47,14 +43,10 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
4743
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
4844
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
4945
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
50-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
51-
// CHECK: %[[TRANSP:.+]] = linalg.transpose
52-
// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>)
53-
// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x8xf32>)
54-
// CHECK-SAME: permutation = [0, 1]
46+
// CHECK-NOT: linalg.transpose
5547
// They have the same type, so the insert_slice op is folded
5648
// away.
57-
// CHECK: return %[[TRANSP]]
49+
// CHECK: return %[[TILE]]
5850

5951
// -----
6052

@@ -75,7 +67,6 @@ func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x
7567
// away.
7668
// CHECK: return %[[TRANSP]]
7769

78-
7970
// -----
8071

8172
func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {

0 commit comments

Comments
 (0)