Skip to content

[mlir] Rename GeneralizeOuterUnitDims{Un}PackOpPatterns #116439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyGeneralizeTensorPackUnpackPatternsOp
: Op<Transform_Dialect, "apply_patterns.linalg.generalize_pack_unpack",
def ApplyDecomposeTensorPackUnpackPatternsOp
: Op<Transform_Dialect, "apply_patterns.linalg.decompose_pack_unpack",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to
decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires
all outer dims to be unit.
Collect patterns to decompose tensor.pack and tensor.unpack into e.g.
tensor::PadOp, linalg::transposeOp Ops. Requires all outer dims to be unit.
}];

let assemblyFormat = "attr-dict";
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
/// ```
struct GeneralizeOuterUnitDimsPackOpPattern
struct DecomposeOuterUnitDimsPackOpPattern
: public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::PackOp packOp,
Expand All @@ -1558,7 +1558,7 @@ struct GeneralizeOuterUnitDimsPackOpPattern
/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
/// being all 1s.
struct GeneralizeOuterUnitDimsUnPackOpPattern
struct DecomposeOuterUnitDimsUnPackOpPattern
: public OpRewritePattern<tensor::UnPackOp> {
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
Expand Down Expand Up @@ -1686,7 +1686,7 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
/// outer dims to be unit.
void populateGeneralizePatterns(RewritePatternSet &patterns);
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);

/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,9 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
linalg::populateEraseUnnecessaryInputsPatterns(patterns);
}

void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns(
void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populateGeneralizePatterns(patterns);
linalg::populateDecomposePackUnpackPatterns(patterns);
}

void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,7 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
return perm;
}

LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
Expand Down Expand Up @@ -1239,7 +1239,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
return success();
}

LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
Expand Down Expand Up @@ -1619,7 +1619,7 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
patterns.getContext(), benefit);
}

void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) {
void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
// TODO: Add and test patterns for tensor.unpack
patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(patterns.getContext());
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
// RUN: mlir-opt -split-input-file -transform-interpreter --canonicalize \
// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-pack.mlir' \
// RUN: -transform-interpreter=entry-point=decompose_pack \
// RUN: -transform-interpreter %s | FileCheck %s

func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> {
%0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: mlir-opt --transform-preload-library='transform-library-paths=%p/td/generalize-pack.mlir' -split-input-file --transform-interpreter %s | FileCheck %s
// RUN: mlir-opt -split-input-file \
// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-pack.mlir' \
// RUN: -transform-interpreter=entry-point=decompose_pack %s | FileCheck %s

func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
%c8 = arith.constant 8 : index
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s
// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s

func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s

func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> {
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module @transforms attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
transform.named_sequence @decompose_pack(%module: !transform.any_op {transform.readonly}) {
%pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op

%1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %1 {
transform.apply_patterns.linalg.generalize_pack_unpack
transform.apply_patterns.linalg.decompose_pack_unpack
} : !transform.any_op

transform.yield
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule |\
// DEFINE: mlir-opt --test-linalg-transform-patterns="test-generalize-tensor-pack"\
// DEFINE: mlir-opt --test-linalg-transform-patterns="test-decompose-tensor-pack"\
// DEFINE: --test-transform-dialect-erase-schedule \
// DEFINE: -one-shot-bufferize="bufferize-function-boundaries" \
// DEFINE: -buffer-deallocation-pipeline="private-function-dynamic-ownership" \
Expand Down
24 changes: 12 additions & 12 deletions mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ struct TestLinalgTransforms
*this, "test-generalize-pad-tensor",
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
llvm::cl::init(false)};
Option<bool> testGeneralizeTensorPackOp{
*this, "test-generalize-tensor-pack",
Option<bool> testDecomposeTensorPackOp{
*this, "test-decompose-tensor-pack",
llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
"of tensor and Linalg ops"),
llvm::cl::init(false)};
Option<bool> testGeneralizeTensorUnPackOp{
*this, "test-generalize-tensor-unpack",
Option<bool> testDecomposeTensorUnPackOp{
*this, "test-decompose-tensor-unpack",
llvm::cl::desc(
"Test transform that generalizes unpack ops into a sequence "
"of tensor and Linalg ops"),
Expand Down Expand Up @@ -172,15 +172,15 @@ static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(funcOp.getContext());
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) {
static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

Expand Down Expand Up @@ -237,10 +237,10 @@ void TestLinalgTransforms::runOnOperation() {
return applyLinalgToVectorPatterns(getOperation());
if (testGeneralizePadTensor)
return applyGeneralizePadTensorPatterns(getOperation());
if (testGeneralizeTensorPackOp)
return applyGeneralizeTensorPackPatterns(getOperation());
if (testGeneralizeTensorUnPackOp)
return applyGeneralizeTensorUnPackPatterns(getOperation());
if (testDecomposeTensorPackOp)
return applyDecomposeTensorPackPatterns(getOperation());
if (testDecomposeTensorUnPackOp)
return applyDecomposeTensorUnPackPatterns(getOperation());
if (testSwapSubTensorPadTensor)
return applyExtractSliceOfPadTensorSwapPattern(getOperation());
if (testBubbleUpExtractSliceOpPattern)
Expand Down
Loading