Skip to content

Commit 536486f

Browse files
[MLIR][Linalg] Fix DataLayoutPropagation for tensor.unpack + linalg.generic (llvm#101755)
While pushing down tensor.unpack through linalg.generic, we should take destination-passing style (DPS) into account. The current implementation was enforcing the creation of a tensor.empty() for the final output value; this should instead have been the outs operand of the original linalg.generic. This commit adds a fix for that. Signed-off-by: Abhishek Varma <[email protected]>
1 parent 6d68860 commit 536486f

File tree

2 files changed

+9
-21
lines changed

2 files changed

+9
-21
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,23 +1106,11 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11061106
auto innerDimsPos = destPack.getInnerDimsPos();
11071107
auto outerDimsPerm = destPack.getOuterDimsPerm();
11081108

1109-
// If the output type for the generic differs from the source
1110-
// unpack op, we need to create a new destination tensor. In the
1111-
// dynamic case we always need a new destination.
1112-
auto loc = genericOp.getLoc();
1113-
Value unPackDest = producerUnPackOp.getDest();
1114-
auto genericOutType =
1115-
cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
1116-
if (producerUnPackOp.getDestType() != genericOutType ||
1117-
!genericOutType.hasStaticShape()) {
1118-
unPackDest = tensor::UnPackOp::createDestinationTensor(
1119-
rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
1120-
}
1121-
11221109
// Insert an unPackOp right after the packed generic.
11231110
Value unPackOpRes =
11241111
rewriter
1125-
.create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
1112+
.create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
1113+
destPack.getSource(), innerDimsPos,
11261114
mixedTiles, outerDimsPerm)
11271115
.getResult();
11281116

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
436436
// CHECK-SAME: outs(%[[PACKED_ARG0]]
437437
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
438438
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
439-
// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
439+
// CHECK-SAME: into %[[UNPACKED_ARG0]]
440440

441441
// -----
442442

@@ -475,7 +475,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
475475
// CHECK-SAME: outs(%[[ARG1_PACK]]
476476
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
477477
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
478-
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
478+
// CHECK-SAME: into %[[ARG1]]
479479

480480
// -----
481481

@@ -512,10 +512,9 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
512512
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
513513
// CHECK-SAME: ins(%[[ARG0_PACK]]
514514
// CHECK-SAME: outs(%[[ARG1_PACK]]
515-
// CHECK: %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
516515
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
517516
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
518-
// CHECK-SAME: into %[[ARG0_NEW_EMPTY_UNPACK]]
517+
// CHECK-SAME: into %[[ARG1]]
519518

520519
// -----
521520

@@ -536,6 +535,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
536535
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
537536
// CHECK-LABEL: func.func @forward_tensor_empty
538537
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
538+
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
539539
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
540540
// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
541541
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
@@ -551,7 +551,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
551551
// CHECK-SAME: outs(%[[DEST]]
552552
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
553553
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
554-
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
554+
// CHECK-SAME: into %[[FINAL_RES]]
555555

556556
// -----
557557

@@ -913,6 +913,7 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
913913
// CHECK-LABEL: func.func @unpack_different_destination_shape
914914
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
915915
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
916+
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
916917
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
917918
// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
918919
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack
@@ -923,10 +924,9 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
923924
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
924925
// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
925926
// CHECK-SAME: outs(%[[INIT]]
926-
// CHECK: %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32>
927927
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
928928
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
929-
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
929+
// CHECK-SAME: into %[[FINAL_RES]]
930930
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
931931

932932
// -----

0 commit comments

Comments (0)