@@ -9,3 +9,68 @@ func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1
9
9
%extracted_slice = tensor.extract_slice %inserted_slice [0 , 0 , 0 , 0 ] [1 , 1 , 123 , 456 ] [1 , 1 , 1 , 1 ] : tensor <1 x1 x128 x480 xf32 > to tensor <123 x456 xf32 >
10
10
return %extracted_slice : tensor <123 x456 xf32 >
11
11
}
12
+
13
+ // -----
14
+
15
+ func.func @fold_casting_insert_slice_of_extract_slice (%in : tensor <?x8 x2 x8 xf32 >, %dest : tensor <8 x1 x8 xf32 >) -> tensor <8 x1 x8 xf32 > {
16
+ %extracted_slice = tensor.extract_slice %in [0 , 0 , 0 , 0 ] [1 , 8 , 1 , 8 ] [1 , 1 , 1 , 1 ] : tensor <?x8 x2 x8 xf32 > to tensor <8 x8 xf32 >
17
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest [0 , 0 , 0 ] [8 , 1 , 8 ] [1 , 1 , 1 ] : tensor <8 x8 xf32 > into tensor <8 x1 x8 xf32 >
18
+ return %inserted_slice : tensor <8 x1 x8 xf32 >
19
+ }
20
+ // CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
21
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
22
+ // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
23
+ // CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
24
+ // CHECK: return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
25
+
26
+ // -----
27
+
28
+ func.func @fold_casting_insert_slice_of_strided_extract_slice (%in : tensor <?x8 x2 x8 xf32 >, %dest : tensor <1 x4 x8 xf32 >) -> tensor <1 x4 x8 xf32 > {
29
+ %extracted_slice = tensor.extract_slice %in [0 , 0 , 0 , 0 ] [1 , 4 , 1 , 8 ] [1 , 2 , 1 , 1 ] : tensor <?x8 x2 x8 xf32 > to tensor <4 x8 xf32 >
30
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest [0 , 0 , 0 ] [1 , 4 , 8 ] [1 , 1 , 1 ] : tensor <4 x8 xf32 > into tensor <1 x4 x8 xf32 >
31
+ return %inserted_slice : tensor <1 x4 x8 xf32 >
32
+ }
33
+ // CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
34
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
35
+ // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
36
+ // CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
37
+ // CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
38
+
39
+ // -----
40
+
41
+ func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice (%in : tensor <?x8 x8 xf32 >, %dest : tensor <1 x1 x8 x8 xf32 >) -> tensor <1 x1 x8 x8 xf32 > {
42
+ %extracted_slice = tensor.extract_slice %in [0 , 0 , 0 ] [1 , 8 , 8 ] [1 , 1 , 1 ] : tensor <?x8 x8 xf32 > to tensor <8 x8 xf32 >
43
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest [0 , 0 , 0 , 0 ] [1 , 1 , 8 , 8 ] [1 , 1 , 1 , 1 ] : tensor <8 x8 xf32 > into tensor <1 x1 x8 x8 xf32 >
44
+ return %inserted_slice : tensor <1 x1 x8 x8 xf32 >
45
+ }
46
+ // CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
47
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x8xf32>
48
+ // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
49
+ // CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
50
+ // CHECK: return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
51
+
52
+ // -----
53
+
54
+ func.func @no_fold_strided_insert_slice_of_extract_slice (%in : tensor <?x8 x2 x8 xf32 >, %dest : tensor <1 x4 x4 xf32 >) -> tensor <1 x4 x4 xf32 > {
55
+ %extracted_slice = tensor.extract_slice %in [0 , 0 , 0 , 0 ] [1 , 8 , 1 , 8 ] [1 , 1 , 1 , 1 ] : tensor <?x8 x2 x8 xf32 > to tensor <8 x8 xf32 >
56
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest [0 , 0 , 0 ] [1 , 8 , 8 ] [1 , 2 , 2 ] : tensor <8 x8 xf32 > into tensor <1 x4 x4 xf32 >
57
+ return %inserted_slice : tensor <1 x4 x4 xf32 >
58
+ }
59
+ // CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
60
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
61
+ // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
62
+ // CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
63
+ // CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
64
+
65
+ // -----
66
+
67
+ func.func @no_fold_non_casting_insert_slice_of_extract_slice (%in : tensor <1 x1 x1 x8 x8 xf32 >, %dest : tensor <2 x8 x8 xf32 >) -> tensor <2 x8 x8 xf32 > {
68
+ %extracted_slice = tensor.extract_slice %in [0 , 0 , 0 , 0 , 0 ] [1 , 1 , 1 , 8 , 8 ] [1 , 1 , 1 , 1 , 1 ] : tensor <1 x1 x1 x8 x8 xf32 > to tensor <8 x8 xf32 >
69
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest [0 , 0 , 0 ] [1 , 8 , 8 ] [1 , 1 , 1 ] : tensor <8 x8 xf32 > into tensor <2 x8 x8 xf32 >
70
+ return %inserted_slice : tensor <2 x8 x8 xf32 >
71
+ }
72
+ // CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
73
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
74
+ // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
75
+ // CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
76
+ // CHECK: return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>
0 commit comments