Skip to content

Commit c06ddc0

Browse files
committed
[mlir] Add missing patterns to linalg.decompose_pack_unpack TD Op
This PR is a follow-up to llvm#116373 and llvm#116439, where a Transform Dialect (TD) operation was introduced to collect patterns for decomposing tensor.pack. The second patch renamed the patterns and the TD Op. Originally, adding patterns for `tensor.unpack` was marked as a TODO, which this PR addresses. No new tests are introduced in this PR. Instead, existing tests from: * "decompose-tensor-unpack.mlir" are reused. To achieve this: * The test is updated to use the TD operation `transform.apply_patterns.linalg.decompose_pack_unpack` instead of the flag `--test-linalg-transform-patterns="test-decompose-tensor-unpack"`, avoiding artificial tests created solely for the TD Op. * The TD sequence is saved to a new file, "decompose_unpack.mlir", and preloaded using the option.
1 parent 998bdae commit c06ddc0

File tree

4 files changed

+20
-3
lines changed

4 files changed

+20
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1656,8 +1656,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
16561656
}
16571657

16581658
void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
1659-
// TODO: Add and test patterns for tensor.unpack
16601659
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
1660+
patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
16611661
}
16621662

16631663
void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {

mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file -transform-interpreter --canonicalize \
2+
// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \
3+
// RUN: -transform-interpreter=entry-point=decompose_unpack \
4+
// RUN: -transform-interpreter %s | FileCheck %s
25

36
func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
47
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>

mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file \
2+
// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \
3+
// RUN: -transform-interpreter=entry-point=decompose_unpack %s | FileCheck %s
24

35
func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> {
46
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module @transforms attributes { transform.with_named_sequence } {
2+
transform.named_sequence @decompose_unpack(%module: !transform.any_op {transform.readonly}) {
3+
%pack = transform.structured.match ops{["tensor.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
4+
5+
%1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
6+
transform.apply_patterns to %1 {
7+
transform.apply_patterns.linalg.decompose_pack_unpack
8+
} : !transform.any_op
9+
10+
transform.yield
11+
}
12+
}

0 commit comments

Comments
 (0)