Commit faf5d74

Max191 and Max Dawkins authored
[mlir] Fix DataLayoutPropagation foldings invalidating IR (#140103)
Fixes a bug in DataLayoutPropagation that was replacing generic op destinations with tensor.empty() ops even when the destination operand was being used. Addresses post-merge comment: https://github.com/llvm/llvm-project/pull/138332/files/a9c1dccc3f73793bdd9e1f51ab3a6e15403a8338#r2091193712

Signed-off-by: Max Dawkins <[email protected]>
Co-authored-by: Max Dawkins <[email protected]>
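To make the bug concrete, here is a minimal MLIR sketch of the hazard (ours, not from the patch: the function name @out_is_used is invented and the shapes are loosely modeled on the padded-domain tests below). Because the generic's body reads %out, the init operand carries live data; rewriting the generic's outs to a fresh tensor.empty(), as the pre-fix folding could, would make every %out read undefined:

#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @out_is_used(%src: tensor<8x8x4x8xf32>,
                       %init: tensor<?x64xf32>) -> tensor<?x64xf32> {
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %init, %c0 : tensor<?x64xf32>
  %empty = tensor.empty(%d0) : tensor<?x64xf32>
  // The ?x64 result need not tile evenly by 4x8, so the domain is padded.
  %unpacked = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %empty : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
  %res = linalg.generic
      {indexing_maps = [#map, #map],
       iterator_types = ["parallel", "parallel"]}
      ins(%unpacked : tensor<?x64xf32>)
      outs(%init : tensor<?x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    // %out is read, so outs(%init) is not just a shape carrier: propagation
    // must pack %init for the new outs rather than substitute tensor.empty().
    %sum = arith.addf %in, %out : f32
    linalg.yield %sum : f32
  } -> tensor<?x64xf32>
  return %res : tensor<?x64xf32>
}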
1 parent: 2cdb7f3 · commit: faf5d74

File tree

2 files changed: +37 additions, −17 deletions

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 16 additions & 7 deletions
@@ -312,10 +312,17 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   SmallVector<Value> inputOperands;
   SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
+  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
+    return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
+           packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
+           llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
+  };
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
-    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
+    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
+    if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
       inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
     } else {
       inputOperandsFromUnpackedSource.push_back(packedOperand);
@@ -324,14 +331,16 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
     indexingMaps.push_back(packedIndexingMap);
   }

-  // If the pack and unpack op can be folded:
-  // 1) use unpack op source op for operand to fold unpack -> pack sequence.
-  // 2) init tensor of the generic op can be replaced by the destination of the
-  // pack op.
+  // If the unpack->pack sequences can be folded, use the sources of the
+  // unpack ops in any unpack->pack chains on the generic op operands.
   if (isFoldableUnpackPack) {
     inputOperands = inputOperandsFromUnpackedSource;
-    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
-      dest = destPack.getDest();
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
+      auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
+      if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
+        dest = destUnPack.getSource();
+      }
+    }
   }

   int64_t numInnerLoops = packInfo.getNumTiledLoops();
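For intuition about the new guard, a hedged MLIR sketch (ours, not part of the patch) of what hasEquivalentTiles distinguishes: an unpack->pack chain may only be bypassed when both ops agree on outer_dims_perm, inner_dims_pos, and mixed tiles, because only then is the round trip a no-op on %src:

func.func @fold_guard(%src: tensor<4x8x4x8xf32>,
                      %flat: tensor<16x64xf32>,
                      %repacked: tensor<4x8x4x8xf32>,
                      %retiled: tensor<2x8x8x8xf32>)
    -> (tensor<4x8x4x8xf32>, tensor<2x8x8x8xf32>) {
  // 4x8 outer dims with 4x8 inner tiles unpack to 16x64.
  %u = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %flat : tensor<4x8x4x8xf32> -> tensor<16x64xf32>
  // Equivalent tiles (same inner_dims_pos and inner_tiles, both with the
  // default outer_dims_perm): hasEquivalentTiles holds, so this chain can be
  // folded away in favor of %src.
  %p = linalg.pack %u inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %repacked : tensor<16x64xf32> -> tensor<4x8x4x8xf32>
  // Different inner_tiles (8x8 vs 4x8): a real relayout, which the guard
  // now refuses to fold.
  %q = linalg.pack %u inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %retiled : tensor<16x64xf32> -> tensor<2x8x8x8xf32>
  return %p, %q : tensor<4x8x4x8xf32>, tensor<2x8x8x8xf32>
}

The destination handling changes in the same spirit: rather than folding dest through any defining linalg.pack and taking the pack's dest (often a fresh tensor.empty), the code now only bypasses a genuine unpack->pack round trip, taking the unpack's source instead.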

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 21 additions & 10 deletions
@@ -455,10 +455,9 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
 // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]]]
-// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-SAME: outs(%[[ARG0]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -482,11 +481,14 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-LABEL: func.func @unpack_on_input
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-SAME: outs(%[[ARG1_PACK]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -510,11 +512,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
+// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-SAME: outs(%[[ARG1_PACK]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -1397,10 +1402,13 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
+// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME: into %[[ARG2]]
 // CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
@@ -1419,10 +1427,13 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty
+// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
+// CHECK-SAME: outs(%[[ARG1_PACK]] : tensor<?x8x4x8xf32>)
 // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME: into %[[ARG1]]
 // CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
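As a reading aid, here is a sketch of the output IR shape the updated @unpack_on_input checks expect (assumed: the generic's region is not shown in the diff, so the addf payload and the #map name below are stand-ins). The init %arg1 is packed and used as the outs operand instead of being replaced by an empty tensor:

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

func.func @unpack_on_input_after(%arg0: tensor<12x2x56x56x32xf32>,
                                 %arg1: tensor<12x56x56x64xf32>)
    -> tensor<12x56x56x64xf32> {
  %empty = tensor.empty() : tensor<12x2x56x56x32xf32>
  // %arg1 is packed into the propagated layout, keeping its contents live.
  %arg1_pack = linalg.pack %arg1 outer_dims_perm = [0, 3, 1, 2]
      inner_dims_pos = [3] inner_tiles = [32]
      into %empty : tensor<12x56x56x64xf32> -> tensor<12x2x56x56x32xf32>
  %res = linalg.generic
      {indexing_maps = [#map, #map],
       iterator_types = ["parallel", "parallel", "parallel", "parallel",
                         "parallel"]}
      ins(%arg0 : tensor<12x2x56x56x32xf32>)
      outs(%arg1_pack : tensor<12x2x56x56x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %add = arith.addf %in, %out : f32
    linalg.yield %add : f32
  } -> tensor<12x2x56x56x32xf32>
  %unpacked = linalg.unpack %res outer_dims_perm = [0, 3, 1, 2]
      inner_dims_pos = [3] inner_tiles = [32]
      into %arg1 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  return %unpacked : tensor<12x56x56x64xf32>
}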
