
[mlir][linalg] Extract GeneralizePadOpPattern into a standalone transformation #117329


Status: Merged (1 commit, Nov 26, 2024)
@@ -52,6 +52,17 @@ def ApplyDecomposeTensorPackUnpackPatternsOp
let assemblyFormat = "attr-dict";
}

def ApplyDecomposeTensorPadPatternsOp
: Op<Transform_Dialect, "apply_patterns.linalg.decompose_pad",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to decompose tensor.pad into e.g. tensor::EmptyOp,
linalg::FillOp and tensor::InsertSliceOp.
}];

let assemblyFormat = "attr-dict";
}
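For reference, a minimal transform-dialect sketch of how the new op is driven, modelled on the updated vectorization-pad-patterns.mlir test further down; the sequence name and handle names are illustrative, not part of this change:

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
        %func_op = transform.structured.match ops{["func.func"]} in %arg0
            : (!transform.any_op) -> !transform.op<"func.func">
        transform.apply_patterns to %func_op {
          // Rewrites tensor.pad into tensor.empty + linalg.fill + tensor.insert_slice.
          transform.apply_patterns.linalg.decompose_pad
        } : !transform.op<"func.func">
        transform.yield
      }
    }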

def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1503,8 +1503,8 @@ using OptimizeCopyFn =

/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<tensor::PadOp>(context, benefit) {}
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override;
@@ -1688,6 +1688,10 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
/// outer dims to be unit.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);

/// Populates patterns to decompose tensor.pad into e.g.
/// tensor.empty, linalg.fill, tensor.insert_slice.
void populateDecomposePadPatterns(RewritePatternSet &patterns);
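To illustrate what these patterns do, here is a rough before/after sketch of the decomposition for a statically shaped pad with a constant padding value; the shapes and the padding constant are made up for illustration and %src stands for an arbitrary tensor<4x5xf32> value:

    %cst = arith.constant 0.0 : f32

    // Before: a tensor.pad with a constant padding value.
    %pad = tensor.pad %src low[0, 1] high[0, 2] {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<4x5xf32> to tensor<4x8xf32>

    // After (roughly): allocate, fill with the padding value, insert the source.
    %empty  = tensor.empty() : tensor<4x8xf32>
    %filled = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x8xf32>) -> tensor<4x8xf32>
    %res    = tensor.insert_slice %src into %filled[0, 1] [4, 5] [1, 1]
        : tensor<4x5xf32> into tensor<4x8xf32>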

/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
/// \see rewriteInIm2Col for more details.
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
@@ -25,5 +25,7 @@ using namespace mlir;
//===----------------------------------------------------------------------===//

void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) {
patterns.add<mlir::linalg::GeneralizePadOpPattern>(patterns.getContext());
// TODO: Add the remaining patterns, e.g. to decompose Pack/Unpack Ops.
// Alternatively, delete this file.
patterns.add<mlir::linalg::DecomposePadOpPattern>(patterns.getContext());
}
11 changes: 10 additions & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -234,6 +234,11 @@ void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
linalg::populateDecomposePackUnpackPatterns(patterns);
}

void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populateDecomposePadPatterns(patterns);
}

void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::ControlDropUnitDims options;
@@ -3491,8 +3496,12 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
// Add misc. vectorization patterns (e.g. for tensor.insert_slice)
linalg::populateInsertSliceVectorizationPatterns(patterns);

if (getVectorizePadding())
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
// decomposing it into e.g. linalg.fill.
linalg::populateDecomposePadPatterns(patterns);
}
vector::populateVectorStepLoweringPatterns(patterns);

TrackingListener listener(state, *this);
10 changes: 7 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -921,7 +921,7 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(

/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
Value DecomposePadOpPattern::createFillOrGenerateOp(
RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
const SmallVector<Value> &dynSizes) const {
auto padValue = padOp.getConstantPaddingValue();
@@ -938,8 +938,8 @@ Value GeneralizePadOpPattern::createFillOrGenerateOp(
}
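As the doc comment above notes, a constant padding value maps to linalg.fill, while a padding value computed in the pad op's region falls back to tensor.generate for the destination. A hedged sketch of that fallback, where the yielded value is a made-up, index-dependent computation standing in for the cloned pad region:

    // Hypothetical non-constant padding value: the destination is produced
    // element-by-element with tensor.generate instead of linalg.fill.
    %dest = tensor.generate {
    ^bb0(%i: index, %j: index):
      %iv = arith.index_cast %i : index to i32
      %v  = arith.sitofp %iv : i32 to f32
      tensor.yield %v : f32
    } : tensor<4x8xf32>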

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const {
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const {
// Given an OpFoldResult, return an index-typed value.
auto getIdxValue = [&](OpFoldResult ofr) {
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
@@ -1623,3 +1623,7 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
// TODO: Add and test patterns for tensor.unpack
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
}

void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
patterns.add<DecomposePadOpPattern>(patterns.getContext());
}
6 changes: 0 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2770,12 +2770,6 @@ void mlir::linalg::populateInsertSliceVectorizationPatterns(

void mlir::linalg::populatePadOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
// TODO: The following pattern implements "decomposition" and
// optional "vectorization". Seperate "decomposition" into a sepereate
// pre-processing pattern group.
patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);

// Try these specialized patterns first before resorting to the generic one.
patterns.add<PadOpVectorizationWithTransferReadPattern,
PadOpVectorizationWithTransferWritePattern,
PadOpVectorizationWithInsertSlicePattern>(
@@ -1,4 +1,4 @@
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor" %s | FileCheck %s
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-pad-tensor" %s | FileCheck %s

// CHECK-LABEL: func @generalize_pad_tensor_static_shape(
// CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
6 changes: 6 additions & 0 deletions mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
@@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} {
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

transform.apply_patterns to %func_op {
// TODO: Split into two tests, one for each pattern
transform.apply_patterns.linalg.decompose_pad
transform.apply_patterns.linalg.pad_vectorization
} : !transform.op<"func.func">
transform.yield
@@ -236,6 +238,8 @@ module attributes {transform.with_named_sequence} {
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

transform.apply_patterns to %func_op {
// TODO: Split into two tests, one for each pattern
transform.apply_patterns.linalg.decompose_pad
transform.apply_patterns.linalg.pad_vectorization
} : !transform.op<"func.func">
transform.yield
@@ -270,6 +274,8 @@ module attributes {transform.with_named_sequence} {
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">

transform.apply_patterns to %func_op {
// TODO: Split into two tests, one for each pattern
transform.apply_patterns.linalg.decompose_pad
transform.apply_patterns.linalg.pad_vectorization
} : !transform.op<"func.func">
transform.yield
12 changes: 6 additions & 6 deletions mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -70,8 +70,8 @@ struct TestLinalgTransforms
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
"in vector.contract form"),
llvm::cl::init(false)};
Option<bool> testGeneralizePadTensor{
*this, "test-generalize-pad-tensor",
Option<bool> testDecomposePadTensor{
*this, "test-decompose-pad-tensor",
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
llvm::cl::init(false)};
Option<bool> testDecomposeTensorPackOp{
@@ -166,9 +166,9 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
static void applyDecomposePadPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
patterns.add<DecomposePadOpPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

@@ -235,8 +235,8 @@ void TestLinalgTransforms::runOnOperation() {
return applyVectorTransferForwardingPatterns(getOperation());
if (testGenericToVectorPattern)
return applyLinalgToVectorPatterns(getOperation());
if (testGeneralizePadTensor)
return applyGeneralizePadTensorPatterns(getOperation());
if (testDecomposePadTensor)
return applyDecomposePadPatterns(getOperation());
if (testDecomposeTensorPackOp)
return applyDecomposeTensorPackPatterns(getOperation());
if (testDecomposeTensorUnPackOp)