16
16
#include " mlir/Dialect/Arith/Utils/Utils.h"
17
17
#include " mlir/Dialect/Linalg/IR/Linalg.h"
18
18
#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
19
+ #include " mlir/Dialect/Linalg/Utils/Utils.h"
19
20
#include " mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20
21
#include " mlir/IR/AffineExpr.h"
21
22
#include " mlir/IR/AffineMap.h"
@@ -956,6 +957,64 @@ class FoldWithProducerReshapeOpByExpansion
956
957
ControlFusionFn controlFoldingReshapes;
957
958
};
958
959
960
+ class FoldPadWithProducerReshapeOpByExpansion
961
+ : public OpRewritePattern<tensor::PadOp> {
962
+ public:
963
+ FoldPadWithProducerReshapeOpByExpansion (MLIRContext *context,
964
+ ControlFusionFn foldReshapes,
965
+ PatternBenefit benefit = 1 )
966
+ : OpRewritePattern<tensor::PadOp>(context, benefit),
967
+ controlFoldingReshapes (std::move(foldReshapes)) {}
968
+
969
+ LogicalResult matchAndRewrite (tensor::PadOp padOp,
970
+ PatternRewriter &rewriter) const override {
971
+ tensor::CollapseShapeOp reshapeOp =
972
+ padOp.getSource ().getDefiningOp <tensor::CollapseShapeOp>();
973
+ if (!reshapeOp)
974
+ return failure ();
975
+ if (!reshapeOp->hasOneUse ())
976
+ return failure ();
977
+
978
+ ArrayRef<int64_t > low = padOp.getStaticLow ();
979
+ ArrayRef<int64_t > high = padOp.getStaticHigh ();
980
+ SmallVector<ReassociationIndices> reassociations =
981
+ reshapeOp.getReassociationIndices ();
982
+
983
+ for (auto [reInd, l, h] : llvm::zip_equal (reassociations, low, high)) {
984
+ if (reInd.size () != 1 && l != 0 && h != 0 )
985
+ return failure ();
986
+ }
987
+
988
+ SmallVector<OpFoldResult> newLow, newHigh;
989
+ RankedTensorType expandedType = reshapeOp.getSrcType ();
990
+ RankedTensorType paddedType = padOp.getResultType ();
991
+ SmallVector<int64_t > expandedPaddedShape (expandedType.getShape ());
992
+ for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
993
+ if (reInd.size () == 1 ) {
994
+ expandedPaddedShape[reInd[0 ]] = paddedType.getShape ()[idx];
995
+ }
996
+ for (auto ind : reInd) {
997
+ newLow.push_back (padOp.getMixedLowPad ()[idx]);
998
+ newHigh.push_back (padOp.getMixedHighPad ()[idx]);
999
+ }
1000
+ }
1001
+
1002
+ Location loc = padOp->getLoc ();
1003
+ RankedTensorType expandedPaddedType = paddedType.clone (expandedPaddedShape);
1004
+ auto newPadOp = rewriter.create <tensor::PadOp>(
1005
+ loc, expandedPaddedType, reshapeOp.getSrc (), newLow, newHigh,
1006
+ padOp.getConstantPaddingValue (), padOp.getNofold ());
1007
+
1008
+ rewriter.replaceOpWithNewOp <tensor::CollapseShapeOp>(
1009
+ padOp, padOp.getResultType (), newPadOp.getResult (), reassociations);
1010
+
1011
+ return success ();
1012
+ }
1013
+
1014
+ private:
1015
+ ControlFusionFn controlFoldingReshapes;
1016
+ };
1017
+
959
1018
// / Pattern to fold a tensor.expand_shape op with its producer generic op
960
1019
// / by expanding the dimensionality of the loop in the producer op.
961
1020
struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1761,68 @@ class FoldWithProducerReshapeOpByCollapsing
1702
1761
ControlFusionFn controlFoldingReshapes;
1703
1762
};
1704
1763
1764
+ class FoldPadWithProducerReshapeOpByCollapsing
1765
+ : public OpRewritePattern<tensor::PadOp> {
1766
+ public:
1767
+ FoldPadWithProducerReshapeOpByCollapsing (MLIRContext *context,
1768
+ ControlFusionFn foldReshapes,
1769
+ PatternBenefit benefit = 1 )
1770
+ : OpRewritePattern<tensor::PadOp>(context, benefit),
1771
+ controlFoldingReshapes (std::move(foldReshapes)) {}
1772
+
1773
+ LogicalResult matchAndRewrite (tensor::PadOp padOp,
1774
+ PatternRewriter &rewriter) const override {
1775
+ tensor::ExpandShapeOp reshapeOp =
1776
+ padOp.getSource ().getDefiningOp <tensor::ExpandShapeOp>();
1777
+ if (!reshapeOp)
1778
+ return failure ();
1779
+ if (!reshapeOp->hasOneUse ())
1780
+ return failure ();
1781
+
1782
+ ArrayRef<int64_t > low = padOp.getStaticLow ();
1783
+ ArrayRef<int64_t > high = padOp.getStaticHigh ();
1784
+ SmallVector<ReassociationIndices> reassociations =
1785
+ reshapeOp.getReassociationIndices ();
1786
+
1787
+ for (auto reInd : reassociations) {
1788
+ if (reInd.size () == 1 )
1789
+ continue ;
1790
+ if (llvm::any_of (reInd, [&](int64_t ind) {
1791
+ return low[ind] != 0 || high[ind] != 0 ;
1792
+ })) {
1793
+ return failure ();
1794
+ }
1795
+ }
1796
+
1797
+ SmallVector<OpFoldResult> newLow, newHigh;
1798
+ RankedTensorType collapsedType = reshapeOp.getSrcType ();
1799
+ RankedTensorType paddedType = padOp.getResultType ();
1800
+ SmallVector<int64_t > collapsedPaddedShape (collapsedType.getShape ());
1801
+ for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
1802
+ if (reInd.size () == 1 ) {
1803
+ collapsedPaddedShape[idx] = paddedType.getShape ()[reInd[0 ]];
1804
+ }
1805
+ newLow.push_back (padOp.getMixedLowPad ()[reInd[0 ]]);
1806
+ newHigh.push_back (padOp.getMixedHighPad ()[reInd[0 ]]);
1807
+ }
1808
+
1809
+ Location loc = padOp->getLoc ();
1810
+ RankedTensorType collapsedPaddedType =
1811
+ paddedType.clone (collapsedPaddedShape);
1812
+ auto newPadOp = rewriter.create <tensor::PadOp>(
1813
+ loc, collapsedPaddedType, reshapeOp.getSrc (), newLow, newHigh,
1814
+ padOp.getConstantPaddingValue (), padOp.getNofold ());
1815
+
1816
+ rewriter.replaceOpWithNewOp <tensor::ExpandShapeOp>(
1817
+ padOp, padOp.getResultType (), newPadOp.getResult (), reassociations);
1818
+
1819
+ return success ();
1820
+ }
1821
+
1822
+ private:
1823
+ ControlFusionFn controlFoldingReshapes;
1824
+ };
1825
+
1705
1826
// / Pattern to collapse dimensions.
1706
1827
template <typename LinalgType>
1707
1828
class CollapseLinalgDimensions : public OpRewritePattern <LinalgType> {
@@ -1937,6 +2058,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1937
2058
const ControlFusionFn &controlFoldingReshapes) {
1938
2059
patterns.add <FoldReshapeWithGenericOpByExpansion>(patterns.getContext (),
1939
2060
controlFoldingReshapes);
2061
+ // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2062
+ // controlFoldingReshapes);
1940
2063
patterns.add <FoldWithProducerReshapeOpByExpansion>(patterns.getContext (),
1941
2064
controlFoldingReshapes);
1942
2065
}
@@ -1946,6 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1946
2069
const ControlFusionFn &controlFoldingReshapes) {
1947
2070
patterns.add <FoldWithProducerReshapeOpByCollapsing>(patterns.getContext (),
1948
2071
controlFoldingReshapes);
2072
+ // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2073
+ // patterns.getContext(), controlFoldingReshapes);
1949
2074
}
1950
2075
1951
2076
void mlir::linalg::populateElementwiseOpsFusionPatterns (
0 commit comments