[mlir] Add reshape propagation patterns for tensor.pad #94489

Merged: 3 commits merged into llvm:main on Jun 7, 2024

Conversation

@Max191 (Contributor) commented Jun 5, 2024

This PR adds fusion-by-collapsing and fusion-by-expansion patterns for tensor.pad ops in ElementwiseOpFusion. A pad op can be expanded or collapsed as long as none of its padded dimensions would be expanded or collapsed by the reshape.
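
As a minimal sketch of the expansion direction (hypothetical shapes and function name, mirroring the tests added in this PR), a pad whose source is a collapse_shape is rewritten into a pad of the uncollapsed source followed by the collapse. This is legal here because the only padded dimension (dim 1) maps 1-1 through the reshape:

func.func @pad_of_collapse(%arg0 : tensor<3x4x5xi32>) -> tensor<12x8xi32> {
  %collapse = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<3x4x5xi32> into tensor<12x5xi32>
  %cst = arith.constant 0 : i32
  %padded = tensor.pad %collapse low[0, 2] high[0, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : i32
  } : tensor<12x5xi32> to tensor<12x8xi32>
  return %padded : tensor<12x8xi32>
}

// After FoldPadWithProducerReshapeOpByExpansion, the body becomes:
//   %padded = tensor.pad %arg0 low[0, 0, 2] high[0, 0, 1] {...} : tensor<3x4x5xi32> to tensor<3x4x8xi32>
//   %collapse = tensor.collapse_shape %padded [[0, 1], [2]] : tensor<3x4x8xi32> into tensor<12x8xi32>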

@llvmbot (Member) commented Jun 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (Max191)

Changes

This PR adds fusion-by-collapsing and fusion-by-expansion patterns for tensor.pad ops in ElementwiseOpFusion. A pad op can be expanded or collapsed as long as none of its padded dimensions would be expanded or collapsed by the reshape.


Full diff: https://github.com/llvm/llvm-project/pull/94489.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+136)
  • (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+68)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+61)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..d93ef9138c474 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -956,6 +956,64 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
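+/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape
+/// op by expanding the pad to the pre-collapse shape and moving the
+/// collapse_shape below the pad. This only applies when no padded dimension
+/// belongs to a collapsed reassociation group.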
+class FoldPadWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::CollapseShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
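+    // Nonzero padding is only supported on dimensions that map 1-1 through
+    // the reshape (unit reassociation groups).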
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && (l != 0 || h != 0))
+        return failure();
+    }
+
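+    // Build the expanded padded shape and replicate the pad amounts across
+    // each reassociation group; multi-dim groups carry only zero padding
+    // (checked above).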
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType expandedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1) {
+        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
+      }
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1760,80 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
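+/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
+/// by collapsing the pad to the pre-expansion shape and moving the
+/// expand_shape below the pad. This only applies when no padded dimension
+/// belongs to an expanded reassociation group.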
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
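+    // Nonzero padding is only supported on dimensions left unexpanded by
+    // the reshape (unit reassociation groups).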
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    SmallVector<OpFoldResult> expandedPaddedSizes(
+        getMixedValues(reshapeOp.getStaticOutputShape(),
+                       reshapeOp.getOutputShape(), rewriter));
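+    // For padded unit groups, the corresponding expand_shape output size
+    // becomes low + high + srcSize, folded through an affine map below.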
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+    Location loc = reshapeOp->getLoc();
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
+      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+      if (reInd.size() == 1) {
+        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+        expandedPaddedSizes[reInd[0]] = paddedSize;
+      }
+      newLow.push_back(l);
+      newHigh.push_back(h);
+    }
+
+    RankedTensorType collapsedPaddedType =
+        paddedType.clone(collapsedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+        expandedPaddedSizes);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -1946,6 +2080,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 0d40df534a3bb..600f0dea31f4a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPAND_ARG0]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+//      CHECK: func @fuse_by_collapsing_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+//      CHECK:       tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+//      CHECK:   return %[[EXPAND]]
+
+// -----
+
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+//      CHECK: func @no_fuse_by_collapsing_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+//      CHECK:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME:       low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+//      CHECK:       tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+//      CHECK:   return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+    %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+  %cst = arith.constant 0.0 : f32
+  %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+//      CHECK: func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+//      CHECK:   %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+//      CHECK:   %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+//      CHECK:       tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME:       output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+//      CHECK:   return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..b8df5fc88e199 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:     [0, 1], [2, 3]
 // CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T4]]
+
+// -----
+
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+//      CHECK: func @fuse_by_expanding_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+//      CHECK:       tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+  return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+//      CHECK: func @no_fuse_by_expanding_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME:       low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+//      CHECK:       tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+//      CHECK:   return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst : i32
+  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+  return %padded_0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME:   %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+//      CHECK:       tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME:       : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+//      CHECK:   return %[[COLLAPSE]]

@MaheshRavishankar (Contributor) left a comment

The control function is not being used. Please fix that.

@@ -1937,6 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
@MaheshRavishankar (Contributor) commented:

AFAICS both the patterns added here are "propagating by collapsing". You are moving the expand_shape down and collapse_shape up. In both cases the pad is happening on collapsed dimensions. So you should add both to the populateFoldReshapeOpsByCollapsingPatterns.

@Max191 (Contributor, Author) replied:

FoldPadWithProducerReshapeOpByExpansion pushes the producer collapse_shape down, expanding the pad. FoldPadWithProducerReshapeOpByCollapsing pushes the producer expand_shape down, collapsing the pad. I think these are in the right places here.
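
Concretely (a hedged sketch with hypothetical shapes; the tests in the diff exercise the general case, and the pad regions are elided here), the collapsing direction rewrites

%expand = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 4, 5] : tensor<4x20xf32> into tensor<4x4x5xf32>
%padded = tensor.pad %expand low[1, 0, 0] high[2, 0, 0] {...} : tensor<4x4x5xf32> to tensor<7x4x5xf32>

into

%padded = tensor.pad %arg0 low[1, 0] high[2, 0] {...} : tensor<4x20xf32> to tensor<7x20xf32>
%expand = tensor.expand_shape %padded [[0], [1, 2]] output_shape [7, 4, 5] : tensor<7x20xf32> into tensor<7x4x5xf32>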

@MaheshRavishankar (Contributor) replied:

Ok, I twisted myself in a knot when looking at tests. You are right.

@Max191 requested a review from MaheshRavishankar on June 6, 2024 at 16:33
@Max191 (Contributor, Author) commented Jun 6, 2024

> The control function is not being used. Please fix that.

Oh oops, I forgot to fix this. Fixing now.

@Max191 force-pushed the pad-reshape-propagation branch from 1735c49 to ecf0347 on June 6, 2024 at 17:24

@Max191 merged commit c886d66 into llvm:main on Jun 7, 2024. 7 checks passed.
@HerrCai0907 mentioned this pull request on Jun 13, 2024.