Skip to content

Commit 2498d7d

Browse files
committed
transpose(broadcast) -> broadcast folder (squashed)
1 parent 70627af commit 2498d7d

File tree

2 files changed

+198
-1
lines changed

2 files changed

+198
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6151,12 +6151,107 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61516151
}
61526152
};
61536153

6154+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6155+
/// 'order preserving', where 'order preserving' means the flattened
6156+
/// inputs and outputs of the transpose have identical (numerical) values.
6157+
///
6158+
/// Example:
6159+
/// ```
6160+
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6161+
/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6162+
/// to vector<8x1xi32>
6163+
/// ```
6164+
/// can be rewritten as the equivalent
6165+
/// ```
6166+
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6167+
/// ```
6168+
/// The algorithm works by partitioning dimensions into groups that can be
6169+
/// locally permuted while preserving order, and checks that the transpose
6170+
/// only permutes within these groups.
6171+
///
6172+
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6173+
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6174+
/// broadcasting from 1x1x4x1x1x7.
6175+
/// ^^^ ^ ^^^ ^
6176+
/// groups: 0 1 2 3
6177+
/// Order preserving permutations for this example are ones that only permute
6178+
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6179+
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6180+
public:
6181+
using OpRewritePattern::OpRewritePattern;
6182+
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6183+
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6184+
6185+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6186+
PatternRewriter &rewriter) const override {
6187+
6188+
vector::BroadcastOp broadcast =
6189+
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6190+
if (!broadcast) {
6191+
return rewriter.notifyMatchFailure(transpose,
6192+
"not preceded by a broadcast");
6193+
}
6194+
6195+
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6196+
6197+
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid, and
6198+
// transpose(broadcast(all ones)) -> broadcast(all ones) is always valid
6199+
bool inputIsScalar = !inputType;
6200+
bool inputIsSizeOneVector = inputType.getNumElements() == 1;
6201+
if (inputIsScalar || inputIsSizeOneVector) {
6202+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6203+
transpose, transpose.getResultVectorType(), transpose.getVector());
6204+
return success();
6205+
}
6206+
6207+
ArrayRef<int64_t> permutation = transpose.getPermutation();
6208+
ArrayRef<int64_t> inputShape = inputType.getShape();
6209+
int64_t inputRank = inputType.getRank();
6210+
int64_t outputRank = transpose.getType().getRank();
6211+
int64_t deltaRank = outputRank - inputRank;
6212+
6213+
int low = 0;
6214+
for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6215+
bool notOne = inputShape[inputIndex] != 1;
6216+
bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6217+
bool groupEndFound = notOne || prevNotOne;
6218+
if (groupEndFound) {
6219+
// Return failure if all not permutation destinations for indices in
6220+
// [low, high) are in [low, high), i.e. the permutation is not local to
6221+
// the group.
6222+
int high = inputIndex + deltaRank;
6223+
for (int i = low; i < high; ++i) {
6224+
if (permutation[i] < low || permutation[i] >= high) {
6225+
return rewriter.notifyMatchFailure(
6226+
transpose, "permutation not local to group");
6227+
}
6228+
}
6229+
}
6230+
}
6231+
6232+
// We don't need to check the final group [low, outputRank) because
6233+
// if it is not locally bound, there must be a preceding group that
6234+
// already failed the check (impossible to have just 1 non-locally
6235+
// bound group).
6236+
6237+
// The preceding logic also ensures that at this point, the output of the
6238+
// transpose is definitely broadcastable from the input shape, so we
6239+
// don't need to check vector::isBroadcastableTo now.
6240+
6241+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6242+
transpose, transpose.getResultVectorType(), transpose.getVector());
6243+
6244+
return success();
6245+
}
6246+
};
6247+
61546248
} // namespace
61556249

61566250
/// Adds the canonicalization patterns for vector.transpose to `results`.
/// This span is a diff hunk: the line following the `-` marker is the old end
/// of the pattern list, and the lines following the `+` markers are its
/// replacement, which additionally lists FoldTransposeBroadcast.
void vector::TransposeOp::getCanonicalizationPatterns(
61576251
RewritePatternSet &results, MLIRContext *context) {
61586252
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
6159-
TransposeFolder, FoldTransposeSplat>(context);
6253+
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
6254+
context);
61606255
}
61616256

61626257
//===----------------------------------------------------------------------===//
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
2+
3+
// This file contains some canonicalizations tests involving vector.transpose.
4+
5+
// A transpose of a broadcast scalar is expected to become one broadcast of
// the scalar directly to the transposed shape.
// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
func.func @scalar_broadcast_transpose_to_broadcast(%scalar : i8) -> vector<2x3x4xi8> {
  // CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
  %bcast = vector.broadcast %scalar : i8 to vector<3x4x2xi8>
  %transposed = vector.transpose %bcast, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
  // CHECK: return %[[BC]] : vector<2x3x4xi8>
  return %transposed : vector<2x3x4xi8>
}
14+
15+
// -----
16+
17+
// The broadcast source has unit extent in every dimension, so the transpose
// is expected to fold into a single broadcast to the transposed shape.
// CHECK-LABEL: ones_broadcast_transpose_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
// CHECK: return %[[RES]] : vector<2x3x4xi8>
func.func @ones_broadcast_transpose_to_broadcast(%unit : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
  %bcast = vector.broadcast %unit : vector<1x1x1xi8> to vector<3x4x2xi8>
  %transposed = vector.transpose %bcast, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
  return %transposed : vector<2x3x4xi8>
}
26+
27+
// -----
28+
29+
// A single-element rank-1 source is broadcast with a rank increase and then
// transposed; the pair is expected to fold into one broadcast.
// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
// CHECK: return %[[RES]] : vector<8x1xi8>
func.func @partial_ones_broadcast_transpose_to_broadcast(%unit : vector<1xi8>) -> vector<8x1xi8> {
  %bcast = vector.broadcast %unit : vector<1xi8> to vector<1x8xi8>
  %transposed = vector.transpose %bcast, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
  return %transposed : vector<8x1xi8>
}
38+
39+
// -----
40+
41+
// The permutation only swaps the two leading dimensions, both of which are
// added by the broadcast, so the transpose is expected to fold away.
// CHECK-LABEL: broadcast_transpose_mixed_example
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
func.func @broadcast_transpose_mixed_example(%src : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
  %bcast = vector.broadcast %src : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
  %transposed = vector.transpose %bcast, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
  return %transposed : vector<3x2x4x5x6x7xi8>
}
50+
51+
// -----
52+
53+
// Negative test: the permutation exchanges a non-unit source dimension with
// a broadcast one, so both ops are expected to survive canonicalization.
// CHECK-LABEL: negative_broadcast_transpose_square
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
// CHECK: return %[[TRP]] : vector<4x4xi8>
func.func @negative_broadcast_transpose_square(%src : vector<4x1xi8>) -> vector<4x4xi8> {
  %bcast = vector.broadcast %src : vector<4x1xi8> to vector<4x4xi8>
  %transposed = vector.transpose %bcast, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
  return %transposed : vector<4x4xi8>
}
63+
64+
// -----
65+
66+
// Negative test: the last two dimensions are swapped although only one of
// them is a unit dimension in the source, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_hypercube
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
func.func @negative_broadcast_transpose_hypercube(%src : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
  %bcast = vector.broadcast %src : vector<1x1x4xi8> to vector<4x4x4x4xi8>
  %transposed = vector.transpose %bcast, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
  return %transposed : vector<4x4x4x4xi8>
}
76+
77+
// -----
78+
79+
// Negative test: permutation [1, 0, 2] swaps the leading non-unit source
// dimension with the unit one, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_102
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_102(%src : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
  %bcast = vector.broadcast %src : vector<3x1x3xi8> to vector<3x3x3xi8>
  %transposed = vector.transpose %bcast, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
  return %transposed : vector<3x3x3xi8>
}
89+
90+
// -----
91+
92+
// Negative test: permutation [0, 2, 1] swaps the unit source dimension with
// the trailing non-unit one, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_021
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
  %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
  return %1 : vector<3x3x3xi8>
}
102+

0 commit comments

Comments
 (0)