@@ -926,6 +926,24 @@ func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index)
926
926
927
927
// -----
928
928
929
// Pack with no explicit outer_dims_perm: bubbling through the collapse_shape
// must still remap inner_dims_pos ([0, 1] -> [1, 2]) past the expanded dims.
func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
  %pack = tensor.pack %collapsed inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
  func.return %pack : tensor<?x4x8x1xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
// CHECK:         return %[[COLLAPSED]] : tensor<?x4x8x1xf32>

// -----
929
947
func.func @bubble_up_permuted_pack_through_collapse (%1: tensor <4 x192 x16 x256 xf32 >) -> tensor <4 x32 x3072 x8 x1 xf32 > {
930
948
%collapsed = tensor.collapse_shape %1 [[0 ], [1 , 2 ], [3 ]] : tensor <4 x192 x16 x256 xf32 > into tensor <4 x3072 x256 xf32 >
931
949
%2 = tensor.empty () : tensor <4 x32 x3072 x8 x1 xf32 >
@@ -1269,6 +1287,27 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
1269
1287
1270
1288
// -----
1271
1289
1290
// Unpack with no explicit outer_dims_perm: pushing down through the
// expand_shape must remap inner_dims_pos ([0, 1] -> [1, 2]) to the expanded
// iteration space.
func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
  %6 = tensor.empty(%dim) : tensor<?x256xf32>
  %unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
  func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand_empty_outer_dims_perm
// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
// CHECK:         %[[C32:.+]] = arith.constant 32 : index
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
// CHECK:         %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C32]] : index
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
// CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// Note: use %[[EXPANDED]] (a use of the captured variable), not a fresh
// wildcard re-capture %[[EXPANDED:.+]], so FileCheck actually verifies the
// unpack consumes the result of the expand_shape above.
// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
// CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>

// -----
1272
1311
func.func @push_down_permuted_unpack_through_expand (%5: tensor <4 x32 x384 x8 x8 xf32 >) -> tensor <4 x12 x256 x256 xf32 > {
1273
1312
%6 = tensor.empty () : tensor <4 x3072 x256 xf32 >
1274
1313
%unpack = tensor.unpack %5 outer_dims_perm = [0 , 2 , 1 ] inner_dims_pos = [2 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <4 x32 x384 x8 x8 xf32 > -> tensor <4 x3072 x256 xf32 >
0 commit comments