Skip to content

Commit c5dade6

Browse files
committed
Fixups
1 parent 272f375 commit c5dade6

File tree

5 files changed

+52
-16
lines changed

5 files changed

+52
-16
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ namespace detail {
5656
struct BitmaskEnumStorage;
5757
} // namespace detail
5858

59+
/// Predefined constant_mask kinds.
60+
enum class ConstantMaskKind { AllFalse = 0, AllTrue };
61+
5962
/// Default callback to build a region with a 'vector.yield' terminator with no
6063
/// arguments.
6164
void buildTerminatedBody(OpBuilder &builder, Location loc);

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2362,6 +2362,11 @@ def Vector_ConstantMaskOp :
23622362
```
23632363
}];
23642364

2365+
let builders = [
2366+
// Build with mixed static/dynamic operands.
2367+
OpBuilder<(ins "VectorType":$type, "ConstantMaskKind":$kind)>
2368+
];
2369+
23652370
let extraClassDeclaration = [{
23662371
/// Return the result type of this op.
23672372
VectorType getVectorType() {

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5749,6 +5749,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
57495749
// ConstantMaskOp
57505750
//===----------------------------------------------------------------------===//
57515751

5752+
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
5753+
VectorType type, ConstantMaskKind kind) {
5754+
assert(kind == ConstantMaskKind::AllTrue ||
5755+
kind == ConstantMaskKind::AllFalse);
5756+
build(builder, result, type,
5757+
kind == ConstantMaskKind::AllTrue
5758+
? type.getShape()
5759+
: SmallVector<int64_t>(type.getRank(), 0));
5760+
}
5761+
57525762
LogicalResult ConstantMaskOp::verify() {
57535763
auto resultType = llvm::cast<VectorType>(getResult().getType());
57545764
// Check the corner case of 0-D vectors first.

mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ using namespace mlir;
1717
using namespace mlir::vector;
1818
namespace {
1919

20-
/// If `value` is a constant multiple of `vector.vscale` return the multiplier.
20+
/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
21+
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
22+
/// `std::nullopt`.
2123
std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
2224
if (value.getDefiningOp<vector::VectorScaleOp>())
2325
return 1;
@@ -78,17 +80,16 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
7880
if (failed(dimLowerBoundSize))
7981
return failure();
8082
if (dimLowerBoundSize->scalable) {
81-
// If the lower bound is scalable and < the mask dim size then this dim is
82-
// not all-true.
83+
// 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
84+
// this dim is not all-true.
8385
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
8486
return failure();
8587
} else {
86-
// If the lower bound is a constant:
88+
// 2. The lower bound, LB, is a constant.
8789
// - If the mask dim size is scalable then this dim is not all-true.
8890
if (maskTypeDimScalableFlags[i])
8991
return failure();
90-
// - If the lower bound is < the _fixed-size_ mask dim size then this dim
91-
// is not all-true.
92+
// - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
9293
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
9394
return failure();
9495
}
@@ -97,8 +98,8 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
9798
// Replace createMaskOp with an all-true constant. This should result in the
9899
// mask being removed in most cases (as xfer ops + vector.mask have folds to
99100
// remove all-true masks).
100-
auto allTrue = rewriter.create<arith::ConstantOp>(
101-
createMaskOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
101+
auto allTrue = rewriter.create<vector::ConstantMaskOp>(
102+
createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
102103
rewriter.replaceAllUsesWith(createMaskOp, allTrue);
103104
return success();
104105
}

mlir/test/Dialect/Vector/eliminate-masks.mlir

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
// RUN: mlir-opt %s -split-input-file -test-eliminate-vector-masks | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file -test-eliminate-vector-masks --split-input-file | FileCheck %s
22

33
// This tests a general pattern the vectorizer tends to emit.
44

55
// CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
6-
// CHECK: %[[ALL_TRUE_MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
6+
// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [4] : vector<[4]xi1>
77
// CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
88
// CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
99
func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor<1x1000xf32>) {
@@ -40,7 +40,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
4040
// -----
4141

4242
// CHECK-LABEL: @negative_extract_slice_size_shrink
43-
// CHECK-NOT: arith.constant dense<true> : vector<[4]xi1>
43+
// CHECK-NOT: vector.constant_mask
4444
// CHECK: %[[MASK:.*]] = vector.create_mask
4545
// CHECK: "test.some_use"(%[[MASK]]) : (vector<[4]xi1>) -> ()
4646
func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
@@ -67,8 +67,25 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
6767

6868
// -----
6969

70+
// CHECK-LABEL: @trivially_all_true_case
71+
// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [2, 4] : vector<2x[4]xi1>
72+
// CHECK: "test.some_use"(%[[ALL_TRUE_MASK]]) : (vector<2x[4]xi1>) -> ()
73+
func.func @trivially_all_true_case(%tensor: tensor<2x?xf32>)
74+
{
75+
%c2 = arith.constant 2 : index
76+
%c4 = arith.constant 4 : index
77+
%vscale = vector.vscale
78+
%c4_vscale = arith.muli %vscale, %c4 : index
79+
// Is found to be all true _without_ value bounds analysis.
80+
%mask = vector.create_mask %c2, %c4_vscale : vector<2x[4]xi1>
81+
"test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
82+
return
83+
}
84+
85+
// -----
86+
7087
// CHECK-LABEL: @negative_constant_dim_not_all_true
71-
// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
88+
// CHECK-NOT: vector.constant_mask
7289
// CHECK: %[[MASK:.*]] = vector.create_mask
7390
// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
7491
func.func @negative_constant_dim_not_all_true()
@@ -87,7 +104,7 @@ func.func @negative_constant_dim_not_all_true()
87104
// -----
88105

89106
// CHECK-LABEL: @negative_constant_vscale_multiple_not_all_true
90-
// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
107+
// CHECK-NOT: vector.constant_mask
91108
// CHECK: %[[MASK:.*]] = vector.create_mask
92109
// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
93110
func.func @negative_constant_vscale_multiple_not_all_true() {
@@ -105,7 +122,7 @@ func.func @negative_constant_vscale_multiple_not_all_true() {
105122
// -----
106123

107124
// CHECK-LABEL: @negative_value_bounds_fixed_dim_not_all_true
108-
// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
125+
// CHECK-NOT: vector.constant_mask
109126
// CHECK: %[[MASK:.*]] = vector.create_mask
110127
// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
111128
func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>)
@@ -114,7 +131,7 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
114131
%c4 = arith.constant 4 : index
115132
%vscale = vector.vscale
116133
%c4_vscale = arith.muli %vscale, %c4 : index
117-
// This is _very_ simple but since tensor.dim is not a constant value bounds
134+
// This is _very_ simple, but since tensor.dim is not a constant, value bounds
118135
// will be used to resolve it.
119136
%dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
120137
%mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
@@ -125,7 +142,7 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
125142
// -----
126143

127144
// CHECK-LABEL: @negative_value_bounds_scalable_dim_not_all_true
128-
// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
145+
// CHECK-NOT: vector.constant_mask
129146
// CHECK: %[[MASK:.*]] = vector.create_mask
130147
// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
131148
func.func @negative_value_bounds_scalable_dim_not_all_true(%tensor: tensor<2x100xf32>) {

0 commit comments

Comments
 (0)