Skip to content

Commit 7a90358

Browse files
committed
address review comments: notify match failure, and create new test file
1 parent e63d35b commit 7a90358

File tree

3 files changed

+144
-108
lines changed

3 files changed

+144
-108
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "llvm/ADT/SmallVector.h"
4343
#include "llvm/ADT/StringSet.h"
4444
#include "llvm/ADT/TypeSwitch.h"
45+
#include "llvm/Support/FormatVariadic.h"
4546

4647
#include <cassert>
4748
#include <cstdint>
@@ -6172,49 +6173,57 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61726173
/// The algorithm works by partitioning dimensions into groups that can be
61736174
/// locally permuted while preserving order, and checks that the transpose
61746175
/// only permutes within these groups.
6176+
///
6177+
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6178+
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6179+
/// broadcasting from 1x1x4x1x1x7.
6180+
/// ^^^ ^ ^^^ ^
6181+
/// groups: 0 1 2 3
6182+
/// Order preserving permutations for this example are ones that only permute
6183+
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
61756184
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
61766185
public:
61776186
using OpRewritePattern::OpRewritePattern;
61786187
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
61796188
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
61806189

6181-
static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
6190+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6191+
PatternRewriter &rewriter) const override {
61826192

61836193
vector::BroadcastOp broadcast =
61846194
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6185-
if (!broadcast)
6186-
return false;
6195+
if (!broadcast) {
6196+
return rewriter.notifyMatchFailure(transpose,
6197+
"not preceded by a broadcast");
6198+
}
61876199

61886200
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6201+
6202+
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
61896203
bool inputIsScalar = !inputType;
6204+
if (inputIsScalar) {
6205+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6206+
transpose, transpose.getResultVectorType(), transpose.getVector());
6207+
return success();
6208+
}
6209+
6210+
ArrayRef<int64_t> permutation = transpose.getPermutation();
61906211
ArrayRef<int64_t> inputShape = inputType.getShape();
61916212
int64_t inputRank = inputType.getRank();
61926213
int64_t outputRank = transpose.getType().getRank();
61936214
int64_t deltaRank = outputRank - inputRank;
61946215

6195-
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6196-
if (inputIsScalar)
6197-
return true;
6198-
61996216
// Return true if all permutation destinations for indices in [low, high)
62006217
// are in [low, high), so the permutation is local to the group.
6201-
auto isGroupBound = [&](int low, int high) {
6202-
ArrayRef<int64_t> permutation = transpose.getPermutation();
6203-
for (int j = low; j < high; ++j) {
6204-
if (permutation[j] < low || permutation[j] >= high) {
6218+
auto isGroupBound = [permutation](int low, int high) {
6219+
for (int i = low; i < high; ++i) {
6220+
if (permutation[i] < low || permutation[i] >= high) {
62056221
return false;
62066222
}
62076223
}
62086224
return true;
62096225
};
62106226

6211-
// Groups are either contiguous sequences of 1s and non-1s (1-element
6212-
// groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
6213-
// to broadcasting from 1x1x4x1x1x7.
6214-
// ^^^ ^ ^^^ ^
6215-
// groups: 0 1 2 3
6216-
// Order preserving permutations for this example are ones that only permute
6217-
// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
62186227
int low = 0;
62196228
for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
62206229
bool notOne = inputShape[inputIndex] != 1;
@@ -6223,32 +6232,29 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62236232
if (groupEndFound) {
62246233
int high = inputIndex + deltaRank;
62256234
if (!isGroupBound(low, high)) {
6226-
return false;
6235+
return rewriter.notifyMatchFailure(
6236+
transpose, llvm::formatv("output dimensions in interval [{0}, "
6237+
"{1}) aren't locally permuted.",
6238+
low, high));
62276239
}
62286240
low = high;
62296241
}
62306242
}
62316243
if (!isGroupBound(low, outputRank)) {
6232-
return false;
6244+
return rewriter.notifyMatchFailure(
6245+
transpose,
6246+
llvm::formatv("output dimensions in final interval [{0}, {1}) "
6247+
"aren't locally permuted.",
6248+
low, outputRank));
62336249
}
62346250

6235-
// The preceding logic ensures that by this point, the ouutput of the
6236-
// transpose is definitely broadcastable from the input shape. So we don't
6237-
// need to call 'vector::isBroadcastableTo', but asserting here just as a
6238-
// sanity check:
6251+
// The preceding logic ensures that at this point, the output of the
6252+
// transpose is definitely broadcastable from the input shape. We confirm
6253+
// this as a sanity check:
62396254
bool isBroadcastable =
62406255
vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
62416256
vector::BroadcastableToResult::Success;
6242-
assert(isBroadcastable &&
6243-
"(I think) it must be broadcastable at this point.");
6244-
6245-
return true;
6246-
}
6247-
6248-
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6249-
PatternRewriter &rewriter) const override {
6250-
if (!canFoldIntoPrecedingBroadcast(transpose))
6251-
return failure();
6257+
assert(isBroadcastable && "It must be broadcastable at this point.");
62526258

62536259
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
62546260
transpose, transpose.getResultVectorType(), transpose.getVector());

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
6666

6767
// -----
6868

69+
// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
70+
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
71+
func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
72+
// CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
73+
%0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
74+
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
75+
// CHECK: return %[[BC]] : vector<2x3x4xi8>
76+
return %1 : vector<2x3x4xi8>
77+
}
78+
79+
// -----
80+
6981
// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
7082
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
7183
func.func @create_mask_transpose_to_transposed_create_mask(
@@ -2215,80 +2227,6 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
22152227

22162228
// -----
22172229

2218-
// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
2219-
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
2220-
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
2221-
// CHECK: return %[[RES]] : vector<2x3x4xi8>
2222-
func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
2223-
%0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
2224-
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
2225-
return %1 : vector<2x3x4xi8>
2226-
}
2227-
2228-
// -----
2229-
2230-
// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
2231-
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
2232-
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
2233-
// CHECK: return %[[RES]] : vector<2x3x4xi8>
2234-
func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
2235-
%0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
2236-
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
2237-
return %1 : vector<2x3x4xi8>
2238-
}
2239-
2240-
// -----
2241-
2242-
// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
2243-
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
2244-
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
2245-
// CHECK: return %[[RES]] : vector<8x1xi8>
2246-
func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
2247-
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
2248-
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
2249-
return %1 : vector<8x1xi8>
2250-
}
2251-
2252-
// -----
2253-
2254-
// CHECK-LABEL: broadcast_transpose_mixed_example_folds
2255-
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
2256-
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
2257-
// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
2258-
func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
2259-
%0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
2260-
%1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
2261-
return %1 : vector<3x2x4x5x6x7xi8>
2262-
}
2263-
2264-
// -----
2265-
2266-
// CHECK-LABEL: broadcast_transpose_102_nofold
2267-
// CHECK-SAME: %[[ARG:.*]]:
2268-
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
2269-
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
2270-
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
2271-
func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
2272-
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
2273-
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
2274-
return %1 : vector<3x3x3xi8>
2275-
}
2276-
2277-
// -----
2278-
2279-
// CHECK-LABEL: broadcast_transpose_021_nofold
2280-
// CHECK-SAME: %[[ARG:.*]]:
2281-
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
2282-
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
2283-
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
2284-
func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
2285-
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
2286-
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
2287-
return %1 : vector<3x3x3xi8>
2288-
}
2289-
2290-
// -----
2291-
22922230
// CHECK-LABEL: func.func @insert_1d_constant
22932231
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
22942232
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
2+
3+
// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose.
4+
5+
intentional bug to sanity check CI picks this new test up
6+
7+
// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
8+
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
9+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
10+
// CHECK: return %[[RES]] : vector<2x3x4xi8>
11+
func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
12+
%0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
13+
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
14+
return %1 : vector<2x3x4xi8>
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
20+
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
21+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
22+
// CHECK: return %[[RES]] : vector<8x1xi8>
23+
func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
24+
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
25+
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
26+
return %1 : vector<8x1xi8>
27+
}
28+
29+
// -----
30+
31+
// CHECK-LABEL: broadcast_transpose_mixed_example_folds
32+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
33+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
34+
// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
35+
func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
36+
%0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
37+
%1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
38+
return %1 : vector<3x2x4x5x6x7xi8>
39+
}
40+
41+
// -----
42+
43+
// CHECK-LABEL: broadcast_transpose_square_nofold
44+
// CHECK-SAME: %[[ARG:.*]]:
45+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
46+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
47+
// CHECK: return %[[TRP]] : vector<4x4xi8>
48+
func.func @broadcast_transpose_square_nofold(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
49+
%0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
50+
%1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
51+
return %1 : vector<4x4xi8>
52+
}
53+
54+
// -----
55+
56+
// CHECK-LABEL: broadcast_transpose_hypercube_nofold
57+
// CHECK-SAME: %[[ARG:.*]]:
58+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
59+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
60+
// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
61+
func.func @broadcast_transpose_hypercube_nofold(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
62+
%0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
63+
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
64+
return %1 : vector<4x4x4x4xi8>
65+
}
66+
67+
// -----
68+
69+
// CHECK-LABEL: broadcast_transpose_102_nofold
70+
// CHECK-SAME: %[[ARG:.*]]:
71+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
72+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
73+
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
74+
func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
75+
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
76+
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
77+
return %1 : vector<3x3x3xi8>
78+
}
79+
80+
// -----
81+
82+
// CHECK-LABEL: broadcast_transpose_021_nofold
83+
// CHECK-SAME: %[[ARG:.*]]:
84+
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
85+
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
86+
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
87+
func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
88+
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
89+
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
90+
return %1 : vector<3x3x3xi8>
91+
}
92+

0 commit comments

Comments
 (0)