Skip to content

Commit a95ad2d

Browse files
IanWood1 and Max191 authored
[mlir] Add bubbling patterns for non intersecting reshapes (#103401)
Refactored @Max191's PR #94637 to move it to `Tensor`. From the original PR: "This PR adds fusion-by-expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other. I'm not sure if I put the code/tests in the right places, so let me know where those go if they aren't." cc @MaheshRavishankar @hanhanW --------- Co-authored-by: Max Dawkins <[email protected]>
1 parent f6e3dbc commit a95ad2d

File tree

5 files changed

+141
-0
lines changed

5 files changed

+141
-0
lines changed

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
6767
/// `tensor.collapse_shape` into other ops.
6868
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
6969

70+
/// Populates `patterns` with patterns that bubble up `tensor.expand_shape`
71+
/// through `tensor.collapse_shape` ops.
72+
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns);
73+
7074
/// Populates `patterns` with patterns that fold tensor.empty with its
7175
/// consumers.
7276
///

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1818
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1919
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
2021
#include "mlir/IR/AffineExpr.h"
2122
#include "mlir/IR/AffineMap.h"
2223
#include "mlir/IR/Matchers.h"
@@ -2144,6 +2145,7 @@ struct LinalgElementwiseOpFusionPass
21442145
// Add elementwise op fusion patterns.
21452146
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
21462147
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2148+
tensor::populateBubbleUpExpandShapePatterns(patterns);
21472149

21482150
// General canonicalization patterns.
21492151
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,76 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
140140
return success();
141141
}
142142
};
143+
144+
/// Pattern to bubble up a tensor.expand_shape op through a producer
145+
/// tensor.collapse_shape op that has non intersecting reassociations.
146+
struct BubbleUpExpandThroughParallelCollapse
147+
: public OpRewritePattern<tensor::ExpandShapeOp> {
148+
using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
149+
150+
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
151+
PatternRewriter &rewriter) const override {
152+
auto collapseOp =
153+
expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
154+
if (!collapseOp)
155+
return failure();
156+
auto expandReInds = expandOp.getReassociationIndices();
157+
auto collapseReInds = collapseOp.getReassociationIndices();
158+
159+
// Reshapes are parallel to each other if none of the reassociation indices
160+
// have greater than 1 index for both reshapes.
161+
for (auto [expandReassociation, collapseReassociation] :
162+
llvm::zip_equal(expandReInds, collapseReInds)) {
163+
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
164+
return failure();
165+
}
166+
167+
// Compute new reassociation indices and expanded/collaped shapes.
168+
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
169+
Location loc = expandOp->getLoc();
170+
SmallVector<OpFoldResult> collapseSizes =
171+
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
172+
SmallVector<OpFoldResult> expandSizes(getMixedValues(
173+
expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
174+
SmallVector<OpFoldResult> newExpandSizes;
175+
int64_t index = 0, expandIndex = 0, collapseIndex = 0;
176+
for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
177+
if (collapseReassociation.size() != 1) {
178+
ReassociationIndices newCollapseReassociation;
179+
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
180+
newCollapseReassociation.push_back(index);
181+
newExpandReInds.push_back({index++});
182+
newExpandSizes.push_back(collapseSizes[collapseIndex++]);
183+
}
184+
newCollapseReInds.push_back(newCollapseReassociation);
185+
expandIndex++;
186+
continue;
187+
}
188+
ReassociationIndices newExpandReassociation;
189+
auto expandReassociation = expandReInds[idx];
190+
for (size_t i = 0; i < expandReassociation.size(); ++i) {
191+
newExpandReassociation.push_back(index);
192+
newCollapseReInds.push_back({index++});
193+
newExpandSizes.push_back(expandSizes[expandIndex++]);
194+
}
195+
newExpandReInds.push_back(newExpandReassociation);
196+
collapseIndex++;
197+
}
198+
199+
// Swap reshape order.
200+
SmallVector<Value> dynamicSizes;
201+
SmallVector<int64_t> staticSizes;
202+
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
203+
auto expandResultType = expandOp.getResultType().clone(staticSizes);
204+
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
205+
loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
206+
newExpandSizes);
207+
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
208+
expandOp, newExpand.getResult(), newCollapseReInds);
209+
return success();
210+
}
211+
};
212+
143213
} // namespace
144214

145215
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -152,3 +222,8 @@ void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
152222
FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
153223
patterns.getContext());
154224
}
225+
226+
/// Populates `patterns` with patterns that bubble up `tensor.expand_shape`
/// through a producer `tensor.collapse_shape` with parallel (non-intersecting)
/// reassociations.
void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-expand-shape-bubbling %s | FileCheck %s

func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
//      CHECK: func @bubble_parallel_reshapes
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
//      CHECK:   return %[[COLLAPSE]]

// -----

func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
//      CHECK: func @no_bubble_full_intersecting_reshapes
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
//      CHECK:   return %[[EXPAND]]

// -----

func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0, 1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?xf32>
  %expand = tensor.expand_shape %collapse [[0, 1], [2, 3]]
              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
//      CHECK: func @no_bubble_partial_intersecting_reshapes
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
//      CHECK:   return %[[EXPAND]]

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ struct TestTensorTransforms
7272
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
7373
llvm::cl::init(false)};
7474

75+
// Enables the expand_shape-bubbling patterns (see
// tensor::populateBubbleUpExpandShapePatterns).
// NOTE(review): the original desc was copy-pasted from the reshape-folding
// option above; fixed to describe this flag.
Option<bool> testBubbleUpExpandShapePatterns{
    *this, "test-expand-shape-bubbling",
    llvm::cl::desc(
        "Test bubbling up of expand_shape through collapse_shape"),
    llvm::cl::init(false)};
79+
7580
Option<bool> testFoldIntoPackAndUnpack{
7681
*this, "test-fold-into-pack-and-unpack",
7782
llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
@@ -102,6 +107,12 @@ static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
102107
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
103108
}
104109

110+
/// Greedily applies the expand_shape-bubbling patterns to `rootOp`.
/// The result of the greedy rewrite is intentionally ignored: this is a test
/// driver and convergence failure is not an error here.
static void applyBubbleUpExpandShapePatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateBubbleUpExpandShapePatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
115+
105116
static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
106117
RewritePatternSet patterns(rootOp->getContext());
107118
tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
@@ -386,6 +397,8 @@ void TestTensorTransforms::runOnOperation() {
386397
applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
387398
if (testReassociativeReshapeFolding)
388399
applyReassociativeReshapeFoldingPatterns(rootOp);
400+
if (testBubbleUpExpandShapePatterns)
401+
applyBubbleUpExpandShapePatterns(rootOp);
389402
if (testFoldIntoPackAndUnpack)
390403
applyFoldIntoPackAndUnpackPatterns(rootOp);
391404
if (testRewriteExtractSliceWithTiledCollapseShape) {

0 commit comments

Comments
 (0)