@@ -2609,7 +2609,6 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2609
2609
return success ();
2610
2610
}
2611
2611
};
2612
-
2613
2612
} // namespace
2614
2613
2615
2614
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -6156,9 +6155,9 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6156
6155
}
6157
6156
};
6158
6157
6159
- // / Folds transpose(broadcast(x)) into broadcast(x) if the transpose is
6160
- // / 'order preserving', where 'order preserving' here means the flattened
6161
- // / inputs and outputs of the transpose have identical values.
6158
+ // / Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6159
+ // / 'order preserving', where 'order preserving' means the flattened
6160
+ // / inputs and outputs of the transpose have identical (numerical) values.
6162
6161
// /
6163
6162
// / Example:
6164
6163
// / ```
@@ -6188,10 +6187,10 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6188
6187
6189
6188
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType ());
6190
6189
bool inputIsScalar = !inputType;
6191
- auto inputShape = inputType.getShape ();
6192
- auto inputRank = inputType.getRank ();
6193
- auto outputRank = transpose.getType ().getRank ();
6194
- auto deltaRank = outputRank - inputRank;
6190
+ ArrayRef< int64_t > inputShape = inputType.getShape ();
6191
+ int64_t inputRank = inputType.getRank ();
6192
+ int64_t outputRank = transpose.getType ().getRank ();
6193
+ int64_t deltaRank = outputRank - inputRank;
6195
6194
6196
6195
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6197
6196
if (inputIsScalar)
@@ -6200,9 +6199,9 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6200
6199
// Return true if all permutation destinations for indices in [low, high)
6201
6200
// are in [low, high), so the permutation is local to the group.
6202
6201
auto isGroupBound = [&](int low, int high) {
6203
- auto perm = transpose.getPermutation ();
6202
+ ArrayRef< int64_t > permutation = transpose.getPermutation ();
6204
6203
for (int j = low; j < high; ++j) {
6205
- if (perm [j] < low || perm [j] >= high) {
6204
+ if (permutation [j] < low || permutation [j] >= high) {
6206
6205
return false ;
6207
6206
}
6208
6207
}
@@ -6233,10 +6232,15 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6233
6232
return false ;
6234
6233
}
6235
6234
6235
+ // The preceding logic ensures that by this point, the output of the
6236
+ // transpose is definitely broadcastable from the input shape. So we don't
6237
+ // need to call 'vector::isBroadcastableTo', but asserting here just as a
6238
+ // sanity check:
6236
6239
bool isBroadcastable =
6237
6240
vector::isBroadcastableTo (inputType, transpose.getResultVectorType ()) ==
6238
6241
vector::BroadcastableToResult::Success;
6239
- assert (isBroadcastable && " it should be broadcastable at this point" );
6242
+ assert (isBroadcastable &&
6243
+ " (I think) it must be broadcastable at this point." );
6240
6244
6241
6245
return true ;
6242
6246
}
0 commit comments