Skip to content

Commit 7decc51

Browse files
committed
tidy
Signed-off-by: James Newling <[email protected]>
1 parent 33a3782 commit 7decc51

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,7 +2609,6 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
26092609
return success();
26102610
}
26112611
};
2612-
26132612
} // namespace
26142613

26152614
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -6156,9 +6155,9 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61566155
}
61576156
};
61586157

6159-
/// Folds transpose(broadcast(x)) into broadcast(x) if the transpose is
6160-
/// 'order preserving', where 'order preserving' here means the flattened
6161-
/// inputs and outputs of the transpose have identical values.
6158+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6159+
/// 'order preserving', where 'order preserving' means the flattened
6160+
/// inputs and outputs of the transpose have identical (numerical) values.
61626161
///
61636162
/// Example:
61646163
/// ```
@@ -6188,10 +6187,10 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
61886187

61896188
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
61906189
bool inputIsScalar = !inputType;
6191-
auto inputShape = inputType.getShape();
6192-
auto inputRank = inputType.getRank();
6193-
auto outputRank = transpose.getType().getRank();
6194-
auto deltaRank = outputRank - inputRank;
6190+
ArrayRef<int64_t> inputShape = inputType.getShape();
6191+
int64_t inputRank = inputType.getRank();
6192+
int64_t outputRank = transpose.getType().getRank();
6193+
int64_t deltaRank = outputRank - inputRank;
61956194

61966195
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
61976196
if (inputIsScalar)
@@ -6200,9 +6199,9 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62006199
// Return true if all permutation destinations for indices in [low, high)
62016200
// are in [low, high), so the permutation is local to the group.
62026201
auto isGroupBound = [&](int low, int high) {
6203-
auto perm = transpose.getPermutation();
6202+
ArrayRef<int64_t> permutation = transpose.getPermutation();
62046203
for (int j = low; j < high; ++j) {
6205-
if (perm[j] < low || perm[j] >= high) {
6204+
if (permutation[j] < low || permutation[j] >= high) {
62066205
return false;
62076206
}
62086207
}
@@ -6233,10 +6232,15 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62336232
return false;
62346233
}
62356234

6235+
// The preceding logic ensures that by this point, the ouutput of the
6236+
// transpose is definitely broadcastable from the input shape. So we don't
6237+
// need to call 'vector::isBroadcastableTo', but asserting here just as a
6238+
// sanity check:
62366239
bool isBroadcastable =
62376240
vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
62386241
vector::BroadcastableToResult::Success;
6239-
assert(isBroadcastable && "it should be broadcastable at this point");
6242+
assert(isBroadcastable &&
6243+
"(I think) it must be broadcastable at this point.");
62406244

62416245
return true;
62426246
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
22

3-
4-
5-
63
// CHECK-LABEL: create_vector_mask_to_constant_mask
74
func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
85
%c2 = arith.constant 2 : index

0 commit comments

Comments
 (0)