Skip to content

Commit 63b926a

Browse files
authored
[mlir] Add apply_patterns.linalg.generalize_pack_unpack TD Op (#116373)
This PR introduces populateGeneralizePatterns, which collects the following patterns: * `GeneralizeOuterUnitDimsPackOpPattern`, * `GeneralizeOuterUnitDimsUnPackOpPattern` (currently a TODO). These patterns are wrapped in a new Transform Dialect Op: `apply_patterns.linalg.generalize_pack_unpack`. This Op facilitates creating more involved end-to-end compilation pipelines for `tensor.pack` and `tensor.unpack` operations. It will be required in an upcoming PR building on top of #115698. No new tests are added in this PR. Instead, existing tests from: * "generalize-tensor-pack.mlir" are reused. To achieve this: * I've updated the test to use `transform.apply_patterns.linalg.generalize_pack_unpack` instead of the flag `--test-linalg-transform-patterns="test-generalize-tensor-pack"`, avoiding artificial tests solely for the TD Op. * The TD sequence is saved to a new file, "generalize_pack.mlir", and pre-loaded using the option: `--transform-preload-library='transform-library-paths=%p/td/generalize_pack.mlir'` This avoids duplicating the sequence for every "split" in the input file. * Added "lit.local.cfg" to exclude the "test/Dialect/Linalg/td" directory from test discovery, ensuring "generalize_pack.mlir" is not treated as a test file.
1 parent 1dcb3db commit 63b926a

File tree

7 files changed

+44
-4
lines changed

7 files changed

+44
-4
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op<Transform_Dialect,
4141
let assemblyFormat = "attr-dict";
4242
}
4343

44+
def ApplyGeneralizeTensorPackUnpackPatternsOp
45+
: Op<Transform_Dialect, "apply_patterns.linalg.generalize_pack_unpack",
46+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
47+
let description = [{
48+
Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to
49+
decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires
50+
all outer dims to be unit.
51+
}];
52+
53+
let assemblyFormat = "attr-dict";
54+
}
55+
4456
def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
4557
"apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
4658
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,8 +1516,8 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
15161516
};
15171517

15181518
/// Rewrites a tensor::PackOp into a sequence of:
1519-
/// * tensor::PadOp + linalg::TransposeOp +
1520-
/// tensor::EmptyOp + tensor::InsertSliceOp ops.
1519+
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
1520+
/// tensor::InsertSliceOp ops.
15211521
///
15221522
/// Required that all the outer dims of the input tensor::PackOp are 1.
15231523
///
@@ -1683,6 +1683,11 @@ void populateLinalgGenericOpsSpecializationPatterns(
16831683
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
16841684
PatternBenefit benefit = 1);
16851685

1686+
/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
1687+
/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
1688+
/// outer dims to be unit.
1689+
void populateGeneralizePatterns(RewritePatternSet &patterns);
1690+
16861691
/// Populates patterns to transform linalg.conv_2d_xxx operations into
16871692
/// linalg.generic (for img2col packing) and linalg.matmul.
16881693
/// \see rewriteInIm2Col for more details.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,11 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
229229
linalg::populateEraseUnnecessaryInputsPatterns(patterns);
230230
}
231231

232+
void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns(
233+
RewritePatternSet &patterns) {
234+
linalg::populateGeneralizePatterns(patterns);
235+
}
236+
232237
void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
233238
RewritePatternSet &patterns) {
234239
linalg::ControlDropUnitDims options;

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,3 +1618,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
16181618
DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
16191619
patterns.getContext(), benefit);
16201620
}
1621+
1622+
void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) {
1623+
// TODO: Add and test patterns for tensor.unpack
1624+
patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(patterns.getContext());
1625+
}

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
2-
1+
// RUN: mlir-opt --transform-preload-library='transform-library-paths=%p/td/generalize-pack.mlir' -split-input-file --transform-interpreter %s | FileCheck %s
32

43
func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
54
%c8 = arith.constant 8 : index
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Skip the directory with input TD sequences
2+
config.excludes = ["td"]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module @transforms attributes { transform.with_named_sequence } {
2+
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
3+
%pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op
4+
5+
%1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
6+
transform.apply_patterns to %1 {
7+
transform.apply_patterns.linalg.generalize_pack_unpack
8+
} : !transform.any_op
9+
10+
transform.yield
11+
}
12+
}

0 commit comments

Comments
 (0)