Skip to content

Commit a33e68b

Browse files
committed
[mlir] Add reshape propagation patterns for tensor.pad
1 parent b5b61cc commit a33e68b

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Arith/Utils/Utils.h"
1717
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1818
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1920
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
2021
#include "mlir/IR/AffineExpr.h"
2122
#include "mlir/IR/AffineMap.h"
@@ -956,6 +957,64 @@ class FoldWithProducerReshapeOpByExpansion
956957
ControlFusionFn controlFoldingReshapes;
957958
};
958959

960+
class FoldPadWithProducerReshapeOpByExpansion
961+
: public OpRewritePattern<tensor::PadOp> {
962+
public:
963+
FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
964+
ControlFusionFn foldReshapes,
965+
PatternBenefit benefit = 1)
966+
: OpRewritePattern<tensor::PadOp>(context, benefit),
967+
controlFoldingReshapes(std::move(foldReshapes)) {}
968+
969+
LogicalResult matchAndRewrite(tensor::PadOp padOp,
970+
PatternRewriter &rewriter) const override {
971+
tensor::CollapseShapeOp reshapeOp =
972+
padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
973+
if (!reshapeOp)
974+
return failure();
975+
if (!reshapeOp->hasOneUse())
976+
return failure();
977+
978+
ArrayRef<int64_t> low = padOp.getStaticLow();
979+
ArrayRef<int64_t> high = padOp.getStaticHigh();
980+
SmallVector<ReassociationIndices> reassociations =
981+
reshapeOp.getReassociationIndices();
982+
983+
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
984+
if (reInd.size() != 1 && l != 0 && h != 0)
985+
return failure();
986+
}
987+
988+
SmallVector<OpFoldResult> newLow, newHigh;
989+
RankedTensorType expandedType = reshapeOp.getSrcType();
990+
RankedTensorType paddedType = padOp.getResultType();
991+
SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
992+
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
993+
if (reInd.size() == 1) {
994+
expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
995+
}
996+
for (auto ind : reInd) {
997+
newLow.push_back(padOp.getMixedLowPad()[idx]);
998+
newHigh.push_back(padOp.getMixedHighPad()[idx]);
999+
}
1000+
}
1001+
1002+
Location loc = padOp->getLoc();
1003+
RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1004+
auto newPadOp = rewriter.create<tensor::PadOp>(
1005+
loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1006+
padOp.getConstantPaddingValue(), padOp.getNofold());
1007+
1008+
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1009+
padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1010+
1011+
return success();
1012+
}
1013+
1014+
private:
1015+
ControlFusionFn controlFoldingReshapes;
1016+
};
1017+
9591018
/// Pattern to fold a tensor.expand_shape op with its producer generic op
9601019
/// by expanding the dimensionality of the loop in the producer op.
9611020
struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1761,68 @@ class FoldWithProducerReshapeOpByCollapsing
17021761
ControlFusionFn controlFoldingReshapes;
17031762
};
17041763

1764+
class FoldPadWithProducerReshapeOpByCollapsing
1765+
: public OpRewritePattern<tensor::PadOp> {
1766+
public:
1767+
FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1768+
ControlFusionFn foldReshapes,
1769+
PatternBenefit benefit = 1)
1770+
: OpRewritePattern<tensor::PadOp>(context, benefit),
1771+
controlFoldingReshapes(std::move(foldReshapes)) {}
1772+
1773+
LogicalResult matchAndRewrite(tensor::PadOp padOp,
1774+
PatternRewriter &rewriter) const override {
1775+
tensor::ExpandShapeOp reshapeOp =
1776+
padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1777+
if (!reshapeOp)
1778+
return failure();
1779+
if (!reshapeOp->hasOneUse())
1780+
return failure();
1781+
1782+
ArrayRef<int64_t> low = padOp.getStaticLow();
1783+
ArrayRef<int64_t> high = padOp.getStaticHigh();
1784+
SmallVector<ReassociationIndices> reassociations =
1785+
reshapeOp.getReassociationIndices();
1786+
1787+
for (auto reInd : reassociations) {
1788+
if (reInd.size() == 1)
1789+
continue;
1790+
if (llvm::any_of(reInd, [&](int64_t ind) {
1791+
return low[ind] != 0 || high[ind] != 0;
1792+
})) {
1793+
return failure();
1794+
}
1795+
}
1796+
1797+
SmallVector<OpFoldResult> newLow, newHigh;
1798+
RankedTensorType collapsedType = reshapeOp.getSrcType();
1799+
RankedTensorType paddedType = padOp.getResultType();
1800+
SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1801+
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1802+
if (reInd.size() == 1) {
1803+
collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1804+
}
1805+
newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
1806+
newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
1807+
}
1808+
1809+
Location loc = padOp->getLoc();
1810+
RankedTensorType collapsedPaddedType =
1811+
paddedType.clone(collapsedPaddedShape);
1812+
auto newPadOp = rewriter.create<tensor::PadOp>(
1813+
loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1814+
padOp.getConstantPaddingValue(), padOp.getNofold());
1815+
1816+
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1817+
padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1818+
1819+
return success();
1820+
}
1821+
1822+
private:
1823+
ControlFusionFn controlFoldingReshapes;
1824+
};
1825+
17051826
/// Pattern to collapse dimensions.
17061827
template <typename LinalgType>
17071828
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2058,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
19372058
const ControlFusionFn &controlFoldingReshapes) {
19382059
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
19392060
controlFoldingReshapes);
2061+
// patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2062+
// controlFoldingReshapes);
19402063
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
19412064
controlFoldingReshapes);
19422065
}
@@ -1946,6 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
19462069
const ControlFusionFn &controlFoldingReshapes) {
19472070
patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
19482071
controlFoldingReshapes);
2072+
// patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2073+
// patterns.getContext(), controlFoldingReshapes);
19492074
}
19502075

19512076
void mlir::linalg::populateElementwiseOpsFusionPatterns(

0 commit comments

Comments
 (0)