Commit ae08e41

yifeizh2 authored and AlexisPerry committed
[mlir][linalg] Fix empty outer dim case for packing reshape op (llvm#96732)
This PR fixes the issue reported in a comment on llvm#93529.
1 parent: b7aaa68

2 files changed: 41 additions, 1 deletion

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (2 additions, 1 deletion)

@@ -605,7 +605,8 @@ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
 static int64_t applyPermutationAndReindexReassoc(
     SmallVector<ReassociationIndices> &reassocIndices,
     ArrayRef<int64_t> permutation) {
-  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
+  if (!permutation.empty())
+    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
   int64_t nextPos = 0;
   for (ReassociationIndices &indices : reassocIndices) {
     for (auto &index : indices) {
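Why the one-line guard is needed: when a pack or unpack op carries no outer_dims_perm attribute, the propagation pattern reaches this helper with an empty permutation, while applyPermutationToVector reorders the vector according to a permutation that is expected to name every element. The following is a minimal standalone C++ sketch of that contract and of the committed guard; permuteInPlace is a hypothetical stand-in written for illustration, not the actual MLIR utility:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for a permutation helper: it assumes the permutation
// names exactly one position per element, so an empty permutation over a
// non-empty vector violates the precondition.
template <typename T>
void permuteInPlace(std::vector<T> &values,
                    const std::vector<int64_t> &permutation) {
  assert(values.size() == permutation.size() &&
         "permutation must cover every element");
  std::vector<T> result;
  result.reserve(values.size());
  for (int64_t idx : permutation)
    result.push_back(values[idx]);
  values = std::move(result);
}

int main() {
  // Reassociation groups for a rank-3 -> rank-2 collapse, e.g. [[0, 1], [2]].
  std::vector<std::vector<int64_t>> reassoc = {{0, 1}, {2}};
  // A pack/unpack op without outer_dims_perm yields an empty permutation.
  std::vector<int64_t> outerDimsPerm;
  // Mirroring the fix: only permute when a permutation is actually present.
  if (!outerDimsPerm.empty())
    permuteInPlace(reassoc, outerDimsPerm);
  return 0;
}

Skipping the call treats the empty permutation as the identity, which matches the meaning of an absent outer_dims_perm.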

mlir/test/Dialect/Linalg/data-layout-propagation.mlir (39 additions, 0 deletions)
@@ -926,6 +926,24 @@ func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index)
 
 // -----
 
+func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
+  %pack = tensor.pack %collapsed inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+  func.return %pack : tensor<?x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK:       %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
+// CHECK:       %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK:       %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+// CHECK:       return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
+
+// -----
+
 func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
   %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
   %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
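This first test pins down the bubble-up direction when the pack op has no outer_dims_perm: the tensor.pack is hoisted above the tensor.collapse_shape, with inner_dims_pos remapped from [0, 1] on the collapsed tensor<?x4xf32> to [1, 2] on the original tensor<?x16x4xf32>, and a trailing collapse_shape of the packed result recovers the expected tensor<?x4x8x1xf32>.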
@@ -1269,6 +1287,27 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
 
 // -----
 
+func.func @push_down_unpack_through_expand_empty_outer_dims_perm(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
+  %6 = tensor.empty(%dim) : tensor<?x256xf32>
+  %unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
+  func.return %expanded : tensor<?x256x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_through_expand_empty_outer_dims_perm
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:       %[[C32:.+]] = arith.constant 32 : index
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
+// CHECK:       %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C32]] : index
+// CHECK:       %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK:       %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
+// CHECK:       %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK:       %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK:       return %[[UNPACK]] : tensor<?x256x256xf32>
+
+// -----
+
 func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
   %6 = tensor.empty() : tensor<4x3072x256xf32>
   %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
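The second test covers the symmetric push-down direction: the tensor.expand_shape is applied to the packed operand first, splitting the dynamic outer dimension (its new size computed via arith.divui by the 32 outer-tile factor), and the tensor.unpack then runs on the expanded tensor<?x32x32x8x8xf32> with inner_dims_pos remapped to [1, 2].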
