42
42
#include " llvm/ADT/SmallVector.h"
43
43
#include " llvm/ADT/StringSet.h"
44
44
#include " llvm/ADT/TypeSwitch.h"
45
+ #include " llvm/Support/FormatVariadic.h"
45
46
46
47
#include < cassert>
47
48
#include < cstdint>
@@ -6172,49 +6173,57 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6172
6173
// / The algorithm works by partitioning dimensions into groups that can be
6173
6174
// / locally permuted while preserving order, and checks that the transpose
6174
6175
// / only permutes within these groups.
6176
+ // /
6177
+ // / Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6178
+ // / Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6179
+ // / broadcasting from 1x1x4x1x1x7.
6180
+ // / ^^^ ^ ^^^ ^
6181
+ // / groups: 0 1 2 3
6182
+ // / Order preserving permutations for this example are ones that only permute
6183
+ // / within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6175
6184
class FoldTransposeBroadcast : public OpRewritePattern <vector::TransposeOp> {
6176
6185
public:
6177
6186
using OpRewritePattern::OpRewritePattern;
6178
6187
FoldTransposeBroadcast (MLIRContext *context, PatternBenefit benefit = 1 )
6179
6188
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6180
6189
6181
- static bool canFoldIntoPrecedingBroadcast (vector::TransposeOp transpose) {
6190
+ LogicalResult matchAndRewrite (vector::TransposeOp transpose,
6191
+ PatternRewriter &rewriter) const override {
6182
6192
6183
6193
vector::BroadcastOp broadcast =
6184
6194
transpose.getVector ().getDefiningOp <vector::BroadcastOp>();
6185
- if (!broadcast)
6186
- return false ;
6195
+ if (!broadcast) {
6196
+ return rewriter.notifyMatchFailure (transpose,
6197
+ " not preceded by a broadcast" );
6198
+ }
6187
6199
6188
6200
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType ());
6201
+
6202
+ // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6189
6203
bool inputIsScalar = !inputType;
6204
+ if (inputIsScalar) {
6205
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
6206
+ transpose, transpose.getResultVectorType (), transpose.getVector ());
6207
+ return success ();
6208
+ }
6209
+
6210
+ ArrayRef<int64_t > permutation = transpose.getPermutation ();
6190
6211
ArrayRef<int64_t > inputShape = inputType.getShape ();
6191
6212
int64_t inputRank = inputType.getRank ();
6192
6213
int64_t outputRank = transpose.getType ().getRank ();
6193
6214
int64_t deltaRank = outputRank - inputRank;
6194
6215
6195
- // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6196
- if (inputIsScalar)
6197
- return true ;
6198
-
6199
6216
// Return true if all permutation destinations for indices in [low, high)
6200
6217
// are in [low, high), so the permutation is local to the group.
6201
- auto isGroupBound = [&](int low, int high) {
6202
- ArrayRef<int64_t > permutation = transpose.getPermutation ();
6203
- for (int j = low; j < high; ++j) {
6204
- if (permutation[j] < low || permutation[j] >= high) {
6218
+ auto isGroupBound = [permutation](int low, int high) {
6219
+ for (int i = low; i < high; ++i) {
6220
+ if (permutation[i] < low || permutation[i] >= high) {
6205
6221
return false ;
6206
6222
}
6207
6223
}
6208
6224
return true ;
6209
6225
};
6210
6226
6211
- // Groups are either contiguous sequences of 1s and non-1s (1-element
6212
- // groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
6213
- // to broadcasting from 1x1x4x1x1x7.
6214
- // ^^^ ^ ^^^ ^
6215
- // groups: 0 1 2 3
6216
- // Order preserving permutations for this example are ones that only permute
6217
- // within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6218
6227
int low = 0 ;
6219
6228
for (int inputIndex = 0 ; inputIndex < inputRank; ++inputIndex) {
6220
6229
bool notOne = inputShape[inputIndex] != 1 ;
@@ -6223,32 +6232,29 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6223
6232
if (groupEndFound) {
6224
6233
int high = inputIndex + deltaRank;
6225
6234
if (!isGroupBound (low, high)) {
6226
- return false ;
6235
+ return rewriter.notifyMatchFailure (
6236
+ transpose, llvm::formatv (" output dimensions in interval [{0}, "
6237
+ " {1}) aren't locally permuted." ,
6238
+ low, high));
6227
6239
}
6228
6240
low = high;
6229
6241
}
6230
6242
}
6231
6243
if (!isGroupBound (low, outputRank)) {
6232
- return false ;
6244
+ return rewriter.notifyMatchFailure (
6245
+ transpose,
6246
+ llvm::formatv (" output dimensions in final interval [{0}, {1}) "
6247
+ " aren't locally permuted." ,
6248
+ low, outputRank));
6233
6249
}
6234
6250
6235
- // The preceding logic ensures that by this point, the ouutput of the
6236
- // transpose is definitely broadcastable from the input shape. So we don't
6237
- // need to call 'vector::isBroadcastableTo', but asserting here just as a
6238
- // sanity check:
6251
+ // The preceding logic ensures that at this point, the output of the
6252
+ // transpose is definitely broadcastable from the input shape. We confirm
6253
+ // this as a sanity check:
6239
6254
bool isBroadcastable =
6240
6255
vector::isBroadcastableTo (inputType, transpose.getResultVectorType ()) ==
6241
6256
vector::BroadcastableToResult::Success;
6242
- assert (isBroadcastable &&
6243
- " (I think) it must be broadcastable at this point." );
6244
-
6245
- return true ;
6246
- }
6247
-
6248
- LogicalResult matchAndRewrite (vector::TransposeOp transpose,
6249
- PatternRewriter &rewriter) const override {
6250
- if (!canFoldIntoPrecedingBroadcast (transpose))
6251
- return failure ();
6257
+ assert (isBroadcastable && " It must be broadcastable at this point." );
6252
6258
6253
6259
rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
6254
6260
transpose, transpose.getResultVectorType (), transpose.getVector ());
0 commit comments