
Commit c886d66

[mlir] Add reshape propagation patterns for tensor.pad (#94489)
This PR adds fusion by collapsing and fusion by expansion patterns for `tensor.pad` ops in ElementwiseOpFusion. Pad ops can be expanded or collapsed as long as none of the padded dimensions will be expanded or collapsed.
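
As a quick illustration, here is a minimal before/after sketch of the expansion-side rewrite, condensed from the `@fuse_by_expanding_pad` test added below (`%arg0` and `%cst` stand for values defined elsewhere): a `tensor.pad` whose source is a `tensor.collapse_shape` is rewritten into a pad of the expanded source followed by the same `tensor.collapse_shape`; the collapsing pattern handles the mirrored `tensor.expand_shape` case.

```mlir
// Before: the pad consumes the result of a collapse_shape; the padded
// dimensions (0, 2, 4 of the collapsed tensor) are not part of any
// collapsed group.
%collapsed = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
    : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%padded = tensor.pad %collapsed low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index, %i4: index):
  tensor.yield %cst : i32
} : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>

// After: the pad is applied to the expanded source and the collapse_shape
// now follows the pad.
%padded_expanded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index,
     %i4: index, %i5: index, %i6: index, %i7: index):
  tensor.yield %cst : i32
} : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
%padded = tensor.collapse_shape %padded_expanded [[0], [1, 2], [3], [4, 5, 6], [7]]
    : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
```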
1 parent 5b2f7a1 commit c886d66

3 files changed: +275 -0 lines changed
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

+146
@@ -956,6 +956,69 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::CollapseShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && (l != 0 || h != 0))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType expandedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1) {
+        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
+      }
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1765,85 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    SmallVector<OpFoldResult> expandedPaddedSizes(
+        getMixedValues(reshapeOp.getStaticOutputShape(),
+                       reshapeOp.getOutputShape(), rewriter));
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+    Location loc = reshapeOp->getLoc();
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
+      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+      if (reInd.size() == 1) {
+        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+        expandedPaddedSizes[reInd[0]] = paddedSize;
+      }
+      newLow.push_back(l);
+      newHigh.push_back(h);
+    }
+
+    RankedTensorType collapsedPaddedType =
+        paddedType.clone(collapsedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+        expandedPaddedSizes);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2079,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -1946,6 +2090,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

+68
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[EXPAND_ARG0]] :
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: return %[[EXPAND]]
+
+// -----
+
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+// CHECK: func @no_fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME: low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+    %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+  %cst = arith.constant 0.0 : f32
+  %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK: func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+// CHECK: %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+// CHECK: return %[[EXPAND]]

mlir/test/Dialect/Linalg/reshape_fusion.mlir

+61
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME: [0, 1], [2, 3]
 // CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
 // CHECK: return %[[T4]]
+
+// -----
+
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+  return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+// CHECK: func @no_fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst : i32
+  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+  return %padded_0 : tensor<?x?x?x?xi32>
+}
+// CHECK: func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK: tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+// CHECK: return %[[COLLAPSE]]
