Skip to content

Commit 0daf20b

Browse files
authored
[mlir][vector] transpose(broadcast) -> broadcast canonicalization (#135096)
Example seen in the 'real world':
```
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
```
This PR adds a canonicalizer that rewrites the above as
```
%1 = vector.broadcast %arg0 : vector<1xi8> to vector<8x1xi8>
```
It works by determining whether a transpose only shuffles dimensions within contiguous groups of broadcast dimensions.
1 parent ed9bcb5 commit 0daf20b

File tree

3 files changed

+235
-48
lines changed

3 files changed

+235
-48
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 96 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6085,28 +6085,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
60856085
}
60866086
};
60876087

6088-
// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
6089-
struct FoldTransposedScalarBroadcast final
6090-
: public OpRewritePattern<vector::TransposeOp> {
6091-
using OpRewritePattern::OpRewritePattern;
6092-
6093-
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6094-
PatternRewriter &rewriter) const override {
6095-
auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
6096-
if (!bcastOp)
6097-
return failure();
6098-
6099-
auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
6100-
if (!srcVectorType || srcVectorType.getNumElements() == 1) {
6101-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6102-
transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
6103-
return success();
6104-
}
6105-
6106-
return failure();
6107-
}
6108-
};
6109-
61106088
// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
61116089
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
61126090
public:
@@ -6161,12 +6139,106 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61616139
}
61626140
};
61636141

6142+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6143+
/// 'order preserving', where 'order preserving' means the flattened
6144+
/// inputs and outputs of the transpose have identical (numerical) values.
6145+
///
6146+
/// Example:
6147+
/// ```
6148+
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6149+
/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6150+
/// to vector<8x1xi32>
6151+
/// ```
6152+
/// can be rewritten as the equivalent
6153+
/// ```
6154+
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6155+
/// ```
6156+
/// The algorithm works by partitioning dimensions into groups that can be
6157+
/// locally permuted while preserving order, and checks that the transpose
6158+
/// only permutes within these groups.
6159+
///
6160+
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6161+
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6162+
/// broadcasting from 1x1x4x1x1x7.
6163+
/// ^^^ ^ ^^^ ^
6164+
/// groups: 0 1 2 3
6165+
/// Order preserving permutations for this example are ones that only permute
6166+
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6167+
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6168+
public:
6169+
using OpRewritePattern::OpRewritePattern;
6170+
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6171+
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6172+
6173+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6174+
PatternRewriter &rewriter) const override {
6175+
6176+
vector::BroadcastOp broadcast =
6177+
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6178+
if (!broadcast) {
6179+
return rewriter.notifyMatchFailure(transpose,
6180+
"not preceded by a broadcast");
6181+
}
6182+
6183+
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6184+
VectorType outputType = transpose.getResultVectorType();
6185+
6186+
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6187+
bool inputIsScalar = !inputType;
6188+
if (inputIsScalar) {
6189+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6190+
transpose.getVector());
6191+
return success();
6192+
}
6193+
6194+
ArrayRef<int64_t> permutation = transpose.getPermutation();
6195+
ArrayRef<int64_t> inputShape = inputType.getShape();
6196+
int64_t inputRank = inputType.getRank();
6197+
int64_t outputRank = transpose.getType().getRank();
6198+
int64_t deltaRank = outputRank - inputRank;
6199+
6200+
int low = 0;
6201+
for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6202+
bool notOne = inputShape[inputIndex] != 1;
6203+
bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6204+
bool groupEndFound = notOne || prevNotOne;
6205+
if (groupEndFound) {
6206+
int high = inputIndex + deltaRank;
6207+
// Return failure if not all permutation destinations for indices in
6208+
// [low, high) are in [low, high), i.e. the permutation is not local to
6209+
// the group.
6210+
for (int i = low; i < high; ++i) {
6211+
if (permutation[i] < low || permutation[i] >= high) {
6212+
return rewriter.notifyMatchFailure(
6213+
transpose, "permutation not local to group");
6214+
}
6215+
}
6216+
}
6217+
}
6218+
6219+
// We don't need to check the final group [low, outputRank) because if it is
6220+
// not locally bound, there must be a preceding group that already failed
6221+
// the check (impossible to have just 1 non-locally bound group).
6222+
6223+
// The preceding logic also ensures that at this point, the output of the
6224+
// transpose is definitely broadcastable from the input shape, assert so:
6225+
assert(vector::isBroadcastableTo(inputType, outputType) ==
6226+
vector::BroadcastableToResult::Success &&
6227+
"not broadcastable directly to transpose output");
6228+
6229+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6230+
transpose.getVector());
6231+
6232+
return success();
6233+
}
6234+
};
6235+
61646236
} // namespace
61656237

61666238
void vector::TransposeOp::getCanonicalizationPatterns(
61676239
RewritePatternSet &results, MLIRContext *context) {
6168-
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
6169-
TransposeFolder, FoldTransposeSplat>(context);
6240+
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
6241+
FoldTransposeBroadcast>(context);
61706242
}
61716243

61726244
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,30 +2218,6 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
22182218

22192219
// -----
22202220

2221-
// CHECK-LABEL: func @transpose_scalar_broadcast1
2222-
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
2223-
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
2224-
// CHECK: return %[[V]] : vector<1x8xf32>
2225-
func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
2226-
%bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
2227-
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
2228-
return %t : vector<1x8xf32>
2229-
}
2230-
2231-
// -----
2232-
2233-
// CHECK-LABEL: func @transpose_scalar_broadcast2
2234-
// CHECK-SAME: (%[[ARG:.+]]: f32)
2235-
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
2236-
// CHECK: return %[[V]] : vector<1x8xf32>
2237-
func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
2238-
%bcast = vector.broadcast %value : f32 to vector<8x1xf32>
2239-
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
2240-
return %t : vector<1x8xf32>
2241-
}
2242-
2243-
// -----
2244-
22452221
// CHECK-LABEL: func @transpose_splat_constant
22462222
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
22472223
// CHECK: return %[[CST]]
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
2+
3+
// This file contains some canonicalizations tests involving vector.transpose.
4+
5+
// CHECK-LABEL: func @transpose_scalar_broadcast1
6+
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
7+
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
8+
// CHECK: return %[[V]] : vector<1x8xf32>
9+
func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
10+
%bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
11+
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
12+
return %t : vector<1x8xf32>
13+
}
14+
15+
// -----
16+
17+
// CHECK-LABEL: func @transpose_scalar_broadcast2
18+
// CHECK-SAME: (%[[ARG:.+]]: f32)
19+
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
20+
// CHECK: return %[[V]] : vector<1x8xf32>
21+
func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
22+
%bcast = vector.broadcast %value : f32 to vector<8x1xf32>
23+
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
24+
return %t : vector<1x8xf32>
25+
}
26+
27+
// -----
28+
29+
30+
// CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
31+
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
32+
func.func @broadcast_transpose_scalar_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
33+
// CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
34+
%0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
35+
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
36+
// CHECK: return %[[BC]] : vector<2x3x4xi8>
37+
return %1 : vector<2x3x4xi8>
38+
}
39+
40+
// -----
41+
42+
// CHECK-LABEL: broadcast_transpose_ones_to_broadcast
43+
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
44+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
45+
// CHECK: return %[[RES]] : vector<2x3x4xi8>
46+
func.func @broadcast_transpose_ones_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
47+
%0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
48+
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
49+
return %1 : vector<2x3x4xi8>
50+
}
51+
52+
// -----
53+
54+
// CHECK-LABEL: broadcast_transpose_partial_ones_to_broadcast
55+
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
56+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
57+
// CHECK: return %[[RES]] : vector<8x1xi8>
58+
func.func @broadcast_transpose_partial_ones_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
59+
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
60+
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
61+
return %1 : vector<8x1xi8>
62+
}
63+
64+
// -----
65+
66+
// CHECK-LABEL: broadcast_transpose_mixed_example
67+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
68+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
69+
// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
70+
func.func @broadcast_transpose_mixed_example(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
71+
%0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
72+
%1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
73+
return %1 : vector<3x2x4x5x6x7xi8>
74+
}
75+
76+
// -----
77+
78+
// CHECK-LABEL: broadcast_transpose_final_group
79+
// CHECK-SAME: %[[ARG:.*]]: vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
80+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x7x1x1xi8> to vector<4x7x2x3xi8>
81+
// CHECK: return %[[RES]] : vector<4x7x2x3xi8>
82+
func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
83+
%0 = vector.broadcast %arg0 : vector<4x7x1x1xi8> to vector<4x7x3x2xi8>
84+
%1 = vector.transpose %0, [0, 1, 3, 2] : vector<4x7x3x2xi8> to vector<4x7x2x3xi8>
85+
return %1 : vector<4x7x2x3xi8>
86+
}
87+
88+
// -----
89+
90+
// CHECK-LABEL: negative_broadcast_transpose_square
91+
// CHECK-SAME: %[[ARG:.*]]:
92+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
93+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
94+
// CHECK: return %[[TRP]] : vector<4x4xi8>
95+
func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
96+
%0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
97+
%1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
98+
return %1 : vector<4x4xi8>
99+
}
100+
101+
// -----
102+
103+
// CHECK-LABEL: negative_broadcast_transpose_hypercube
104+
// CHECK-SAME: %[[ARG:.*]]:
105+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
106+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
107+
// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
108+
func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
109+
%0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
110+
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
111+
return %1 : vector<4x4x4x4xi8>
112+
}
113+
114+
// -----
115+
116+
// CHECK-LABEL: negative_broadcast_transpose_102
117+
// CHECK-SAME: %[[ARG:.*]]:
118+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
119+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
120+
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
121+
func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
122+
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
123+
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
124+
return %1 : vector<3x3x3xi8>
125+
}
126+
127+
// -----
128+
129+
// CHECK-LABEL: negative_broadcast_transpose_021
130+
// CHECK-SAME: %[[ARG:.*]]:
131+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
132+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
133+
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
134+
func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
135+
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
136+
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
137+
return %1 : vector<3x3x3xi8>
138+
}
139+

0 commit comments

Comments
 (0)