Skip to content

Commit 33a3782

Browse files
committed
transpose(broadcast) -> broadcast folder
Signed-off-by: James Newling <[email protected]>
1 parent 7cbf78e commit 33a3782

File tree

2 files changed

+177
-1
lines changed

2 files changed

+177
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2609,6 +2609,7 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
26092609
return success();
26102610
}
26112611
};
2612+
26122613
} // namespace
26132614

26142615
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -6155,12 +6156,110 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61556156
}
61566157
};
61576158

6159+
/// Folds transpose(broadcast(x)) into broadcast(x) if the transpose is
6160+
/// 'order preserving', where 'order preserving' here means the flattened
6161+
/// inputs and outputs of the transpose have identical values.
6162+
///
6163+
/// Example:
6164+
/// ```
6165+
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6166+
/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6167+
/// to vector<8x1xi32>
6168+
/// ```
6169+
/// can be rewritten as the equivalent
6170+
/// ```
6171+
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6172+
/// ```
6173+
/// The algorithm works by partitioning dimensions into groups that can be
6174+
/// locally permuted while preserving order, and checks that the transpose
6175+
/// only permutes within these groups.
6176+
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6177+
public:
6178+
using OpRewritePattern::OpRewritePattern;
6179+
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6180+
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6181+
6182+
static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
6183+
6184+
vector::BroadcastOp broadcast =
6185+
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6186+
if (!broadcast)
6187+
return false;
6188+
6189+
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6190+
bool inputIsScalar = !inputType;
6191+
auto inputShape = inputType.getShape();
6192+
auto inputRank = inputType.getRank();
6193+
auto outputRank = transpose.getType().getRank();
6194+
auto deltaRank = outputRank - inputRank;
6195+
6196+
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6197+
if (inputIsScalar)
6198+
return true;
6199+
6200+
// Return true if all permutation destinations for indices in [low, high)
6201+
// are in [low, high), so the permutation is local to the group.
6202+
auto isGroupBound = [&](int low, int high) {
6203+
auto perm = transpose.getPermutation();
6204+
for (int j = low; j < high; ++j) {
6205+
if (perm[j] < low || perm[j] >= high) {
6206+
return false;
6207+
}
6208+
}
6209+
return true;
6210+
};
6211+
6212+
// Groups are either contiguous sequences of 1s and non-1s (1-element
6213+
// groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
6214+
// to broadcasting from 1x1x4x1x1x7.
6215+
// ^^^ ^ ^^^ ^
6216+
// groups: 0 1 2 3
6217+
// Order preserving permutations for this example are ones that only permute
6218+
// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6219+
int low = 0;
6220+
for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6221+
bool notOne = inputShape[inputIndex] != 1;
6222+
bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6223+
bool groupEndFound = notOne || prevNotOne;
6224+
if (groupEndFound) {
6225+
int high = inputIndex + deltaRank;
6226+
if (!isGroupBound(low, high)) {
6227+
return false;
6228+
}
6229+
low = high;
6230+
}
6231+
}
6232+
if (!isGroupBound(low, outputRank)) {
6233+
return false;
6234+
}
6235+
6236+
bool isBroadcastable =
6237+
vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
6238+
vector::BroadcastableToResult::Success;
6239+
assert(isBroadcastable && "it should be broadcastable at this point");
6240+
6241+
return true;
6242+
}
6243+
6244+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6245+
PatternRewriter &rewriter) const override {
6246+
if (!canFoldIntoPrecedingBroadcast(transpose))
6247+
return failure();
6248+
6249+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6250+
transpose, transpose.getResultVectorType(), transpose.getVector());
6251+
6252+
return success();
6253+
}
6254+
};
6255+
61586256
} // namespace
61596257

61606258
void vector::TransposeOp::getCanonicalizationPatterns(
61616259
RewritePatternSet &results, MLIRContext *context) {
61626260
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
6163-
TransposeFolder, FoldTransposeSplat>(context);
6261+
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
6262+
context);
61646263
}
61656264

61666265
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
22

3+
4+
5+
36
// CHECK-LABEL: create_vector_mask_to_constant_mask
47
func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
58
%c2 = arith.constant 2 : index
@@ -2215,6 +2218,80 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
22152218

22162219
// -----
22172220

2221+
// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
2222+
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
2223+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
2224+
// CHECK: return %[[RES]] : vector<2x3x4xi8>
2225+
// The broadcast source is a scalar (i8), so every element of the broadcast
// result is the same value; the test expects the transpose to be absorbed
// into a single broadcast to the transposed shape 2x3x4.
func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
2226+
%0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
2227+
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
2228+
return %1 : vector<2x3x4xi8>
2229+
}
2230+
2231+
// -----
2232+
2233+
// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
2234+
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
2235+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
2236+
// CHECK: return %[[RES]] : vector<2x3x4xi8>
2237+
// Every source dimension has size 1 (vector<1x1x1xi8>), so all dimensions
// fall into one group and the test expects the [2, 0, 1] transpose to fold
// into a single broadcast to 2x3x4.
func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
2238+
%0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
2239+
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
2240+
return %1 : vector<2x3x4xi8>
2241+
}
2242+
2243+
// -----
2244+
2245+
// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
2246+
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
2247+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
2248+
// CHECK: return %[[RES]] : vector<8x1xi8>
2249+
// Rank-increasing case: vector<1xi8> is broadcast to 1x8 and then transposed
// to 8x1. The source has only a unit dimension, so the test expects a single
// broadcast straight to vector<8x1xi8>.
func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
2250+
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
2251+
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
2252+
return %1 : vector<8x1xi8>
2253+
}
2254+
2255+
// -----
2256+
2257+
// CHECK-LABEL: broadcast_transpose_mixed_example_folds
2258+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
2259+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
2260+
// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
2261+
// The permutation [1, 0, 2, 3, 4, 5] only swaps the two leading dimensions
// (sizes 2 and 3), which are both introduced by the broadcast; the dimensions
// coming from the 4x1x1x7 source keep their positions. The test expects the
// pair to fold into one broadcast to 3x2x4x5x6x7.
func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
2262+
%0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
2263+
%1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
2264+
return %1 : vector<3x2x4x5x6x7xi8>
2265+
}
2266+
2267+
// -----
2268+
2269+
// CHECK-LABEL: broadcast_transpose_102_nofold
2270+
// CHECK-SAME: %[[ARG:.*]]:
2271+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
2272+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
2273+
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
2274+
// Negative test: [1, 0, 2] swaps dimension 0 (size 3 in the source) with
// dimension 1 (size 1 in the source, expanded by the broadcast). The test
// expects both the broadcast and the transpose to remain.
func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
2275+
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
2276+
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
2277+
return %1 : vector<3x3x3xi8>
2278+
}
2279+
2280+
// -----
2281+
2282+
// CHECK-LABEL: broadcast_transpose_021_nofold
2283+
// CHECK-SAME: %[[ARG:.*]]:
2284+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
2285+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
2286+
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
2287+
// Negative test: [0, 2, 1] swaps dimension 1 (size 1 in the source, expanded
// by the broadcast) with dimension 2 (size 3 in the source). The test expects
// both the broadcast and the transpose to remain.
func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
2288+
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
2289+
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
2290+
return %1 : vector<3x3x3xi8>
2291+
}
2292+
2293+
// -----
2294+
22182295
// CHECK-LABEL: func.func @insert_1d_constant
22192296
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
22202297
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>

0 commit comments

Comments
 (0)