Commit 1735c49
add tests, support dynamic expand
1 parent a33e68b

3 files changed, +151 -11 lines


mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 22 additions & 11 deletions
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -981,7 +980,7 @@ class FoldPadWithProducerReshapeOpByExpansion
         reshapeOp.getReassociationIndices();

     for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && l != 0 && h != 0)
+      if (reInd.size() != 1 && (l != 0 || h != 0))
         return failure();
     }

@@ -993,7 +992,7 @@ class FoldPadWithProducerReshapeOpByExpansion
       if (reInd.size() == 1) {
         expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
       }
-      for (auto ind : reInd) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
         newLow.push_back(padOp.getMixedLowPad()[idx]);
         newHigh.push_back(padOp.getMixedHighPad()[idx]);
       }
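
The two hunks above tighten FoldPadWithProducerReshapeOpByExpansion: a reassociation group that spans more than one dimension is now rejected as soon as either its low or its high pad is nonzero (the old `&&` only bailed out when both were), and the inner loop counts the group size instead of naming an unused element. Below is a minimal standalone sketch of the corrected predicate; the names Group, numDims and canFoldPad are hypothetical and there is no MLIR dependency.

// Toy model of the legality check, outside MLIR; names are illustrative only.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct Group {
  std::size_t numDims; // expanded dims that collapse into this padded dim
  int64_t low;         // static low padding on the padded dim
  int64_t high;        // static high padding on the padded dim
};

// Folding is only legal when every multi-dim group is completely unpadded.
static bool canFoldPad(const std::vector<Group> &groups) {
  for (const Group &g : groups)
    if (g.numDims != 1 && (g.low != 0 || g.high != 0))
      return false;
  return true;
}

int main() {
  // Shapes of @fuse_by_expanding_pad below: only singleton groups are padded.
  std::printf("%d\n", canFoldPad({{1, 1, 5}, {2, 0, 0}, {1, 8, 4}, {3, 0, 0}, {1, 3, 2}})); // 1
  // Shapes of @no_fuse_by_expanding_pad: the 3-dim group has high = 3, so the
  // old `l != 0 && h != 0` test missed it while the new `||` rejects it.
  std::printf("%d\n", canFoldPad({{1, 1, 5}, {2, 0, 0}, {1, 8, 4}, {3, 0, 3}, {1, 3, 2}})); // 0
  return 0;
}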
@@ -1798,23 +1797,35 @@ class FoldPadWithProducerReshapeOpByCollapsing
     RankedTensorType collapsedType = reshapeOp.getSrcType();
     RankedTensorType paddedType = padOp.getResultType();
     SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    SmallVector<OpFoldResult> expandedPaddedSizes(
+        getMixedValues(reshapeOp.getStaticOutputShape(),
+                       reshapeOp.getOutputShape(), rewriter));
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+    Location loc = reshapeOp->getLoc();
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
+      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
       if (reInd.size() == 1) {
         collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
-      newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
+      newLow.push_back(l);
+      newHigh.push_back(h);
     }

-    Location loc = padOp->getLoc();
     RankedTensorType collapsedPaddedType =
         paddedType.clone(collapsedPaddedShape);
     auto newPadOp = rewriter.create<tensor::PadOp>(
         loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
         padOp.getConstantPaddingValue(), padOp.getNofold());

     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+        expandedPaddedSizes);

     return success();
   }
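
This hunk is what makes dynamic shapes work for FoldPadWithProducerReshapeOpByCollapsing: for every singleton reassociation group, the output extent handed to the rebuilt tensor.expand_shape becomes low + high + original size, built with makeComposedFoldedAffineApply over the map (d0 + d1 + d2) so constant operands fold to a static extent while dynamic ones materialize an affine.apply. Below is a plain C++ sketch of that bookkeeping; FoldResult and addPadding are toy stand-ins for this sketch only, not the MLIR types.

// Stand-alone model of the padded-size computation, without MLIR.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Toy stand-in for OpFoldResult: either a known constant or a symbolic name.
struct FoldResult {
  std::optional<int64_t> constant;
  std::string symbol; // used when constant is empty
};

static FoldResult addPadding(const FoldResult &low, const FoldResult &high,
                             const FoldResult &size) {
  if (low.constant && high.constant && size.constant)
    return {*low.constant + *high.constant + *size.constant, ""};
  // Dynamic case: the real pattern emits an affine.apply of (d0 + d1 + d2).
  return {std::nullopt, low.symbol + " + " + high.symbol + " + " + size.symbol};
}

int main() {
  // Static group, as in dim 0 of @fuse_by_collapsing_pad: 1 + 5 + 2 = 8.
  FoldResult s = addPadding({1, ""}, {5, ""}, {2, ""});
  std::cout << *s.constant << "\n"; // 8
  // Dynamic group, as in @fuse_by_collapsing_dynamic_pad: stays symbolic.
  FoldResult d = addPadding({std::nullopt, "%l0"}, {std::nullopt, "%h0"},
                            {std::nullopt, "%s0"});
  std::cout << d.symbol << "\n"; // %l0 + %h0 + %s0
  return 0;
}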
@@ -2058,8 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
-  // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
-  //                                                       controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2069,8 +2080,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
-  // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
-  //     patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
 }

 void mlir::linalg::populateElementwiseOpsFusionPatterns(
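
Re-enabling the two FoldPadWithProducerReshapeOp patterns in the populate functions above means any pass that already calls them picks up the pad folding automatically. As a hedged illustration only, where the driver function foldPadReshapes and the always-true control callback are placeholders and not part of this commit, a caller might wire the patterns up like this:

// Sketch of a downstream caller; assumes an in-tree MLIR build.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult foldPadReshapes(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Accept every candidate fusion; a real pass would inspect the operand.
  linalg::ControlFusionFn control = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, control);
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, control);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}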

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 68 additions & 0 deletions
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[EXPAND_ARG0]] :
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: return %[[EXPAND]]
+
+// -----
+
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+// CHECK: func @no_fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME: low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+    %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+  %cst = arith.constant 0.0 : f32
+  %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK: func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+// CHECK: %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+// CHECK: return %[[EXPAND]]

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 61 additions & 0 deletions
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME: [0, 1], [2, 3]
 // CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
 // CHECK: return %[[T4]]
+
+// -----
+
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+  return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+// CHECK: func @no_fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst : i32
+  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+  return %padded_0 : tensor<?x?x?x?xi32>
+}
+// CHECK: func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK: tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+// CHECK: return %[[COLLAPSE]]
