Skip to content

Commit 625eef5

Browse files
committed
fixup! [mlir][linalg] Add scalable vectorisation for depthwise convolutions
Generalise the code a tiny bit to cover for 1D depthwise NCW convs (once supported by the vectoriser).
1 parent 5e46b1b commit 625eef5

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3588,8 +3588,16 @@ static FailureOr<Operation *> vectorizeConvolution(
35883588
if (!inputVecSizes.empty()) {
35893589
// Only use the input vector size corresponding to the channel dim. Other
35903590
// vector dims will be inferred from the Ops.
3591-
vecChDimSize = inputVecSizes[2];
3592-
vecChDimScalableFlag = inputScalableVecDims[2];
3591+
assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3592+
isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3593+
"Not a 1D depthwise conv!");
3594+
size_t chDimIdx =
3595+
TypeSwitch<Operation *, size_t>(op)
3596+
.Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3597+
.Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3598+
3599+
vecChDimSize = inputVecSizes[chDimIdx];
3600+
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
35933601
}
35943602
return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
35953603
flatten1DDepthwiseConv);

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,25 @@ module attributes {transform.with_named_sequence} {
1919

2020
// -----
2121

22+
func.func @depthwise_conv1d_ncw_cw(%input: memref<3x5x4xf32>, %filter: memref<5x1xf32>, %output: memref<3x5x4xf32>) {
23+
// expected-error @+1 {{Attempted to vectorize, but failed}}
24+
linalg.depthwise_conv_1d_ncw_cw
25+
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
26+
ins(%input, %filter : memref<3x5x4xf32>, memref<5x1xf32>)
27+
outs(%output : memref<3x5x4xf32>)
28+
return
29+
}
30+
31+
module attributes {transform.with_named_sequence} {
32+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
33+
%0 = transform.structured.match ops{["linalg.depthwise_conv_1d_ncw_cw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
34+
transform.structured.vectorize %0 vector_sizes [3, 4, 5, 1] : !transform.any_op
35+
transform.yield
36+
}
37+
}
38+
39+
// -----
40+
2241
func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x4xf32>, %filter: memref<?x4xf32>, %output: memref<3x?x4xf32>) {
2342
// expected-error @+1 {{Attempted to vectorize, but failed}}
2443
linalg.depthwise_conv_1d_nwc_wc

0 commit comments

Comments
 (0)