Skip to content

Commit 34e34a0

Browse files
authored
[mlir][linalg] Mark xfers as in-bounds when masking depthwise convs (#96771)
If this is not set the fact that the dynamic channel is in-bounds cannot be inferred automatically (like it can for static sizes), which eventually leads to it being marked as out-of-bounds (which prevents some rewrites).
1 parent bb97378 commit 34e34a0

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3234,6 +3234,11 @@ struct Conv1DGenerator
32343234
auto maskType =
32353235
VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
32363236

3237+
SmallVector<bool> inBounds(maskShape.size(), true);
3238+
auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3239+
xferOp->setAttr(xferOp.getInBoundsAttrName(),
3240+
rewriter.getBoolArrayAttr(inBounds));
3241+
32373242
SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
32383243
cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
32393244

mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,18 @@ module attributes {transform.with_named_sequence} {
3434
// CHECK: %[[C8:.*]] = arith.constant 8 : index
3535
// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x4xi1>
3636
/// Read the input tensor
37-
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
37+
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
3838

3939
/// Create a mask for the filter tensor
4040
// CHECK: %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
4141
// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x4xi1>
4242
/// Read the filter tensor
43-
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>
43+
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>
4444

4545
/// Create a mask for the output tensor
4646
// CHECK: %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
4747
// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x4xi1>
48-
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
48+
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
4949

5050
/// Convolution
5151
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
@@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} {
5555
// CHECK: %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x4xi8>
5656
// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x4xi8>
5757
// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x4xi8> into vector<1x8x4xi8>
58-
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
58+
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
5959
// CHECK: return %[[OUT]] : tensor<1x8x?xi8>
6060

6161
// -----
@@ -95,19 +95,19 @@ module attributes {transform.with_named_sequence} {
9595
// CHECK: %[[C8:.*]] = arith.constant 8 : index
9696
// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x[4]xi1>
9797
/// Read the input tensor
98-
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
98+
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
9999

100100
/// Create a mask for the filter tensor
101101
// CHECK: %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
102102
// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x[4]xi1>
103103
/// Read the filter tensor
104-
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
104+
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
105105

106106
/// Create a mask for the output tensor
107107
// CHECK: %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
108108
// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x[4]xi1>
109109
/// Read the output tensor
110-
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
110+
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
111111

112112
/// Convolution
113113
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
@@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
117117
// CHECK: %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x[4]xi8>
118118
// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x[4]xi8>
119119
// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
120-
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
120+
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
121121
// CHECK: return %[[OUT]] : tensor<1x8x?xi8>
122122

123123
// -----
@@ -157,19 +157,19 @@ module attributes {transform.with_named_sequence} {
157157
// CHECK: %[[C5:.*]] = arith.constant 5 : index
158158
// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C3]], %[[C5]], %[[CH_DIM_IN]] : vector<3x4x[4]xi1>
159159
/// Read the input tensor
160-
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
160+
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
161161

162162
/// Create a mask for the filter tensor
163163
// CHECK: %[[CH_DIM_FLT:.*]] = memref.dim %[[FILTER]], %[[C1]] : memref<2x?xf32>
164164
// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C2]], %[[CH_DIM_FLT]] : vector<2x[4]xi1>
165165
/// Read the filter tensor
166-
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
166+
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
167167

168168
/// Create a mask for the output tensor
169169
// CHECK: %[[CH_DIM_OUT:.*]] = memref.dim %[[OUTPUT]], %[[C2]] : memref<3x2x?xf32>
170170
// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C3]], %[[C2]], %[[CH_DIM_OUT]] : vector<3x2x[4]xi1>
171171
/// Read the output tensor
172-
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
172+
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
173173

174174
/// Convolution
175175
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
@@ -182,4 +182,4 @@ module attributes {transform.with_named_sequence} {
182182
// CHECK: %[[FLT_2_B:.*]] = vector.broadcast %[[FLT_2]] : vector<[4]xf32> to vector<3x2x[4]xf32>
183183
// CHECK: %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
184184
// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
185-
// CHECK: vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
185+
// CHECK: vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>

0 commit comments

Comments
 (0)