Skip to content

Commit 7f42653

Browse files
committed
fixup! [mlir][linalg] Add masked vectorisation for depthwise convolutions
More documentaiton, some simplification (as per Cullen's comments)
1 parent db6ef69 commit 7f42653

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1715,7 +1715,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
17151715
}
17161716

17171717
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
1718-
if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv.getOperation())) {
1718+
if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
17191719
LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
17201720
return failure();
17211721
}
@@ -3597,6 +3597,9 @@ static FailureOr<Operation *> vectorizeConvolution(
35973597
if (succeeded(res))
35983598
return res;
35993599

3600+
// Only depthwise 1D NWC convs are left - these can be vectorized using masks
3601+
// and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3602+
// masked/scalable) is the channel dim (i.e. the trailing dim).
36003603
uint64_t vecChDimSize = ShapedType::kDynamic;
36013604
bool vecChDimScalableFlag = false;
36023605
if (!inputVecSizes.empty()) {

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ module attributes {transform.with_named_sequence} {
4141
func.func @depthwise_conv1d_nwc_wc_dyn_ch_dim(%input: memref<3x5x?xf32>, %filter: memref<2x?xf32>, %output: memref<3x2x?xf32>) {
4242
// expected-error @+1 {{Attempted to vectorize, but failed}}
4343
linalg.depthwise_conv_1d_nwc_wc
44-
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
4544
ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
4645
outs(%output : memref<3x2x?xf32>)
4746
return
@@ -60,7 +59,6 @@ module attributes {transform.with_named_sequence} {
6059
func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x3xf32>, %filter: memref<2x3xf32>, %output: memref<3x?x3xf32>) {
6160
// expected-error @+1 {{Attempted to vectorize, but failed}}
6261
linalg.depthwise_conv_1d_nwc_wc
63-
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
6462
ins(%input, %filter : memref<3x?x3xf32>, memref<2x3xf32>)
6563
outs(%output : memref<3x?x3xf32>)
6664
return

mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir renamed to mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,14 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
135135
return
136136
}
137137

138+
module attributes {transform.with_named_sequence} {
139+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
140+
%0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
141+
transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
142+
transform.yield
143+
}
144+
}
145+
138146
// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
139147
// CHECK-SAME: %[[INPUT:.*]]: memref<3x5x?xf32>,
140148
// CHECK-SAME: %[[FILTER:.*]]: memref<2x?xf32>,
@@ -177,11 +185,3 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
177185
// CHECK: %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
178186
// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
179187
// CHECK: vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
180-
181-
module attributes {transform.with_named_sequence} {
182-
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
183-
%0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
184-
transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
185-
transform.yield
186-
}
187-
}

0 commit comments

Comments
 (0)