[mlir][linalg] Enable masked vectorisation for depthwise convolutions #81625

Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

This patch adds support for scalable vectorisation of depthwise 1D HWC convolutions, `linalg.depthwise_conv_1d_nwc_wc`. This is implemented by adding support for masking.

Two major assumptions are made:
* only the channel dimension can be scalable/dynamic (i.e. the trailing dim),
* when specifying vector sizes to use in the vectoriser, only the size corresponding to the channel dim is effectively used (other dims are inferred from the context).

In terms of scalable vectorisation, this should be sufficient to cover all practical cases. As for more generic cases with dynamic shapes (e.g. w or n dims being dynamic), more work would be needed; in particular, one would have to consider the filter and input/output tensors separately.

Patch is 23.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81625.diff

2 Files Affected:
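To illustrate the shape of the change, here is a condensed sketch distilled from the new test in the diff below (not verbatim compiler output): with a dynamic channel dim and a scalable vector size `[4]`, each transfer Op is wrapped in a `vector.mask` whose bound comes from the runtime channel size:

```mlir
// Sketch only -- see vectorize-conv-scalable.mlir in the diff for the real
// CHECK lines. %input : tensor<1x8x?xi8>, channel dim is dynamic.
%ch   = tensor.dim %input, %c2 : tensor<1x8x?xi8>
%mask = vector.create_mask %c1, %c8, %ch : vector<1x8x[4]xi1>
%vec  = vector.mask %mask {
  vector.transfer_read %input[%c0, %c0, %c0], %pad
      : tensor<1x8x?xi8>, vector<1x8x[4]xi8>
} : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
```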
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..aed96be5b6da00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -54,6 +54,7 @@ using namespace mlir::linalg;
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+ ArrayRef<int64_t> inputVecSizes = {},
bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1609,6 +1610,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
}
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+ // Support dynamic shapes in 1D depthwise convolution, but only in the
+ // _channel_ dimension. That's exclusively to support scalable vectorisation.
+ if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+ auto lhsShaped = op.getDpsInputOperand(0)->get();
+ ArrayRef<int64_t> lhsShape =
+ dyn_cast<ShapedType>(lhsShaped.getType()).getShape();
+ auto shapeWithoutCh = lhsShape.drop_back(1);
+ if (ShapedType::isDynamicShape(shapeWithoutCh))
+ return failure();
+
+ return success();
+ }
+
// TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
// linalg.copy ops and ops that implement ContractionOpInterface for now.
if (!isElementwise(op) &&
@@ -1789,7 +1803,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Only element-wise ops supported in the presence of scalable dims.
auto linalgOp = dyn_cast<LinalgOp>(op);
- return success(linalgOp && isElementwise(linalgOp));
+ return success(linalgOp && (isElementwise(linalgOp) ||
+ isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -1871,7 +1886,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
// features. Will require stride/dilation attributes inference.
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
FailureOr<Operation *> convOr = vectorizeConvolution(
- rewriter, linalgOp, flatten1DDepthwiseConv);
+ rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
return success();
@@ -2697,6 +2712,7 @@ struct Conv1DGenerator
return;
break;
}
+ hasTensorSemantics = linalgOp.hasPureTensorSemantics();
// The op is now known to be valid.
valid = true;
}
@@ -3000,13 +3016,21 @@ struct Conv1DGenerator
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
- FailureOr<Operation *> depthwiseConv(bool flatten) {
+ FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+ bool flatten) {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+ bool scalableChDim = false;
int64_t nSize, wSize, cSize, kwSize;
// kernel{kw, c}
bindShapeDims(rhsShapedType, kwSize, cSize);
+ // Dynamic channel size implies scalable vectorisation
+ if (ShapedType::isDynamic(cSize)) {
+ assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
+ cSize = channelDimVecSize;
+ scalableChDim = true;
+ }
// out{n, w, c}
bindShapeDims(resShapedType, nSize, wSize);
@@ -3027,20 +3051,74 @@ struct Conv1DGenerator
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
cSize},
- lhsEltType);
- VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
- VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+ lhsEltType, {false, false, scalableChDim});
+ VectorType rhsType =
+ VectorType::get({kwSize, cSize}, rhsEltType,
+ /*scalableDims=*/{false, scalableChDim});
+ VectorType resType =
+ VectorType::get({nSize, wSize, cSize}, resEltType,
+ /*scalableDims=*/{false, false, scalableChDim});
+
+ // Masks the input xfer Op along the channel dim, iff the corresponding
+ // scalable flag is set.
+ auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+ ArrayRef<bool> scalableDims,
+ Operation *opToMask) {
+ bool scalableChDim = scalableDims.back();
+ if (!scalableChDim)
+ return opToMask;
+
+ auto maskType =
+ VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+
+ SmallVector<OpFoldResult> mixedSourceDims =
+ hasTensorSemantics
+ ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+ .Case<vector::TransferReadOp>([&](auto readOp) {
+ return tensor::getMixedSizes(rewriter, loc,
+ readOp.getSource());
+ })
+ .Case<vector::TransferWriteOp>([&](auto writeOp) {
+ return tensor::getMixedSizes(rewriter, loc,
+ writeOp.getOperand(1));
+ })
+ : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+ .Case<vector::TransferReadOp>([&](auto readOp) {
+ return memref::getMixedSizes(rewriter, loc,
+ readOp.getSource());
+ })
+ .Case<vector::TransferWriteOp>([&](auto writeOp) {
+ return memref::getMixedSizes(rewriter, loc,
+ writeOp.getOperand(1));
+ });
+
+ Value maskOp =
+ rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+
+ return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
+ };
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ auto maybeMaskedLHS = maybeMaskXferOp(
+ lhsType.getShape(),
+ /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
+
// Read rhs slice of size {kw, c} @ [0, 0].
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
ValueRange{zero, zero});
+ auto maybeMaskedRHS = maybeMaskXferOp(
+ rhsType.getShape(),
+ /*scalableDims=*/{false, scalableChDim}, rhs.getDefiningOp());
+
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
loc, resType, resShaped, ValueRange{zero, zero, zero});
+ auto maybeMaskedRES = maybeMaskXferOp(
+ resType.getShape(),
+ /*scalableDims=*/{false, false, scalableChDim}, res.getDefiningOp());
//===------------------------------------------------------------------===//
// Begin vector-only rewrite part
@@ -3055,7 +3133,7 @@ struct Conv1DGenerator
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, lhs,
+ loc, maybeMaskedLHS->getResult(0),
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
inOutSliceSizes, inOutStrides));
}
@@ -3063,12 +3141,13 @@ struct Conv1DGenerator
// Extract rhs slice of size {c} @ [kw].
for (int64_t kw = 0; kw < kwSize; ++kw) {
rhsVals.push_back(rewriter.create<vector::ExtractOp>(
- loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+ loc, maybeMaskedRHS->getResult(0),
+ /*offsets=*/ArrayRef<int64_t>{kw}));
}
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, res,
+ loc, maybeMaskedRES->getResult(0),
/*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
inOutStrides));
}
@@ -3107,6 +3186,7 @@ struct Conv1DGenerator
// Its possible we failed to create the Fma.
if (!llvm::all_of(resVals, [](Value v) { return v; })) {
// Manually revert (in reverse order) to avoid leaving a bad IR state.
+ // TODO: Replace with maybeMasked
for (auto &collection :
{resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
for (Value v : collection)
@@ -3117,8 +3197,8 @@ struct Conv1DGenerator
// Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
// This does not depend on kw.
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- res = rewriter.create<vector::InsertStridedSliceOp>(
- loc, resVals[w], res,
+ maybeMaskedRES = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, resVals[w], maybeMaskedRES->getResult(0),
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
/*strides=*/ArrayRef<int64_t>{1, 1, 1});
}
@@ -3127,10 +3207,12 @@ struct Conv1DGenerator
//===------------------------------------------------------------------===//
// Write back res slice of size {n, w, c} @ [0, 0, 0].
- return rewriter
- .create<vector::TransferWriteOp>(loc, res, resShaped,
- ValueRange{zero, zero, zero})
- .getOperation();
+ Operation *resOut = rewriter.create<vector::TransferWriteOp>(
+ loc, maybeMaskedRES->getResult(0), resShaped,
+ ValueRange{zero, zero, zero});
+ return maybeMaskXferOp(resType.getShape(),
+ /*scalableDims=*/{false, false, scalableChDim},
+ resOut);
}
/// Lower:
@@ -3171,8 +3253,9 @@ struct Conv1DGenerator
if (!lhs || !rhs)
return nullptr;
- if (isa<FloatType>(resTy.getElementType()))
+ if (isa<FloatType>(resTy.getElementType())) {
return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+ }
auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
return rewriter.create<arith::AddIOp>(loc, mul, res);
@@ -3268,7 +3351,8 @@ struct Conv1DGenerator
/// Entry point that transposes into the common form:
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
- FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
+ FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+ bool flatten = false) {
AffineExpr n, w, c, kw;
bindDims(ctx, n, w, c, kw);
if (!iters({Par(), Par(), Par(), Red()}))
@@ -3279,7 +3363,7 @@ struct Conv1DGenerator
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
- return depthwiseConv(flatten);
+ return depthwiseConv(vecChDimSize, flatten);
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
}
@@ -3291,6 +3375,7 @@ struct Conv1DGenerator
StringAttr redOp;
StringAttr poolExtOp;
bool isPoolExt = false;
+ bool hasTensorSemantics = false;
int strideW, dilationW;
Value lhsShaped, rhsShaped, resShaped;
ShapedType lhsShapedType, rhsShapedType, resShapedType;
@@ -3346,6 +3431,7 @@ struct Conv1DGenerator
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+ ArrayRef<int64_t> inputVecSizes,
bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can simply
@@ -3371,7 +3457,14 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
res = e.generateNcwPooling();
if (succeeded(res))
return res;
- return e.generateDilatedConv(flatten1DDepthwiseConv);
+
+ uint64_t vecChDimSize = ShapedType::kDynamic;
+ if (!inputVecSizes.empty()) {
+ // Only use the input vector size corresponding to the channel dim. Other
+ // vector dims will be inferred from the Ops.
+ vecChDimSize = inputVecSizes[2];
+ }
+ return e.generateDilatedConv(vecChDimSize, flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
new file mode 100644
index 00000000000000..d4b3574451c2bf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
+ %filter: tensor<1x?xi8>,
+ %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+ %res = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+ outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+ return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK-DAG: arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = tensor.dim %[[FILTER]], %[[VAL_5]] : tensor<1x?xi8>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[CH_DIM_SIZE_INPUT:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_SIZE_INPUT]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK: %[[C0_I8_1:.*]] = arith.constant 0 : i8
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[CH_DIM_SIZE_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK: %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C1_1]], %[[CH_DIM_SIZE_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[C0_I8_1]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_23:.*]] = arith.constant 2 : index
+// CHECK: %[[CH_DIM_SIZE_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_23]] : tensor<1x8x?xi8>
+// CHECK: %[[VAL_25:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_26:.*]] = arith.constant 8 : index
+// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[CH_DIM_SIZE_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[VAL_22]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK: %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK: %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK: %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK: %[[FLT_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK: %[[MULI:.*]] = arith.muli %[[VEC_IN_0]], %[[FLT_B]] : vector<1x8x[4]xi8>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[VEC_OUT_0]] : vector<1x8x[4]xi8>
+// CHECK: %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK: %[[VAL_36:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_37:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_36]] : tensor<1x8x?xi8>
+// CHECK: %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_39:.*]] = arith.constant 8 : index
+// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_38]], %[[VAL_39]], %[[VAL_37]] : vector<1x8x[4]xi1>
+
+/// Write the output tensor
+// CHECK: vector.mask %[[MASK_OUT]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x?xf32>,
+ %filter: memref<2x?xf32>,
+ %output: memref<3x2x?xf32>) {
+ linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
+ outs(%output : memref<3x2x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
+// CHECK-SAME: %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME: %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
+
+// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = memref.dim %[[FILTER]], %[[VAL_5]] : memref<2x?xf32>
+// CHECK: %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+
+/// Create a mask for the input tensor
+// CHECK: %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_11:.*]] = memref.dim %[[INPUT]], %[[VAL_10]] : memref<3x5x?xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] : vector<3x4x[4]xi1>
+/// Read the input tensor
+// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[...
[truncated]
Thanks! Taking a first look!
    bool scalableChDim = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    // Dynamic channel size implies scalable vectorisation
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      scalableChDim = true;
    }
Why does a dynamic channel dimension imply scalable vectors? We have to make sure this also works for non-scalable cases.
I wonder this too (though I'm not very familiar with the linalg vectorizer). Could you also use scalable vectors if your channel dimension was large (say 100 elements), but still a static size?
Would it be possible to control this by passing the scalable dims flags for the vector sizes, like is the case with matmuls?
Why does a dynamic channel dimension imply scalable vectors?

Because I ran out of time this week and that's the case I need the most :) I will send an update shortly to add support for the more generic case too. It's actually quite easy. The most labor-intensive task is writing tests for these convs.
Could you also use scalable vectors if your channel dimension was large (say 100 elements), but still a static size?
Yes, we already do that for elementwise Ops. However, you need to tile first - that will lead to tensors with dynamic shapes.
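(A hypothetical sketch of that tile-then-vectorise flow, with illustrative names and sizes, not code from this patch: tiling a static channel dim by a vscale-dependent step yields slices with a dynamic channel dim, which is exactly the case handled here.)

```mlir
%c0   = arith.constant 0 : index
%c4   = arith.constant 4 : index
%c100 = arith.constant 100 : index
%vs   = vector.vscale
%step = arith.muli %c4, %vs : index              // tile size = 4 * vscale
scf.for %ch = %c0 to %c100 step %step {
  %rem = arith.subi %c100, %ch : index
  %sz  = arith.minsi %step, %rem : index         // last tile may be partial
  %in  = tensor.extract_slice %input[0, 0, %ch] [1, 8, %sz] [1, 1, 1]
      : tensor<1x8x100xi8> to tensor<1x8x?xi8>   // dynamic channel dim
  // ... vectorise the depthwise conv on this slice ...
}
```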
Thanks for the reviews, sending updates shortly!
    %filter: memref<2x?xf32>,
    %output: memref<3x2x?xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
what's the significance of a dilation of 2? Is dilation explained anywhere? I can't see anything in the linalg docs.
I've been struggling to find good documentation for this. This book provides a nice description:
See Figure 10.3.
The reason for adding this case here is to demonstrate that this patch is not doing anything that would only work for the default case (dilation = 1). But I appreciate that it's a bit tricky to follow without knowing the semantics of the Op itself. Is there anything I can do to make it clearer? You could try comparing against:
llvm-project/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir, lines 51 to 108 at 982e902:
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
                                                                %filter: memref<2x4xf32>,
                                                                %output: memref<3x2x4xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
    outs(%output : memref<3x2x4xf32>)
  return
}
// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
/// Read the whole data in one shot.
// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> from vector<2x4xf32>
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> from vector<2x4xf32>
/// w == 0, kw = 0
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<8xf32> to vector<3x8xf32>
// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
/// w == 0, kw = 1
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[SH_FILTER_1]] : vector<8xf32> to vector<3x8xf32>
// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
// Write the result back in one shot.
// CHECK: %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
// CHECK: vector.transfer_write %[[SC_FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
But note that I wasn't trying to keep the check-lines consistent and, more importantly, that is testing "flattened" convs.
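For readers unfamiliar with dilation, here is the indexing this patch matches (taken from the `{n, strideW * w + dilationW * kw, c}` layout in the diff above), worked through for the dilation-2 test:

```
out[n, w, c] += in[n, strideW * w + dilationW * kw, c] * filter[kw, c]

With strideW = 1, dilationW = 2, kwSize = 2, wSize = 2:
  w = 0 reads input positions {0, 2};  w = 1 reads {1, 3}
so the lhs read spans
  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1 = 4
elements along w -- which is why the masks in the memref test are
vector<3x4x[4]xi1> for a 3x5x? input producing a 3x2x? output.
```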
Kind ping :) I believe that I've addressed all your comments - could you take another look? This is quite critical for me 🙏🏻 😅
✅ With the latest revision this PR passed the C/C++ code formatter.
I've left a few final comments but otherwise this LGTM, cheers. I would appreciate an integration test as well, but that doesn't need to be in this patch.
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>

/// Convolution
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
The insert/extract strided slice operations in this test (and its scalable equivalent) are not necessary; with canonicalization they are removed, which I think makes the test easier to follow, and also makes it more obvious what is happening in the final test, where the dilation is 2 and the extract_strided_slice is necessary.
I agree that the extra indirection is somewhat confusing, but I deliberately avoid canonicalization in these tests - otherwise we wouldn't be testing the vectoriser. In fact, we have introduced dedicated TD ops to separate:
- tests specifically for the vectoriser: transform.structured.vectorize
- tests for the vectoriser + various patterns: transform.structured.vectorize_children_and_apply_patterns
In particular, once canonicalizations are included, the generated output will depend on many other parts of MLIR and I want to avoid that. Having said that, we could add more tests with patterns.
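For illustration, both forms already appear verbatim in this PR's tests; the first exercises only the vectoriser, while the second additionally applies patterns:

```mlir
// Vectoriser only (used by the new tests in this patch):
transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op

// Vectoriser + patterns (output also depends on canonicalizations):
%2 = transform.structured.vectorize_children_and_apply_patterns %1
    {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
```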
        }
        // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
        for (int64_t w = 0; w < wSize; w += wSizeStep) {
          resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-             loc, res,
+             loc, maybeMaskedRes->getResult(0),
The logic below this (for when `flatten=true`) looks like it will drop the scalable dims:
auto inOutFlattenSliceSizes =
SmallVector<int64_t>{nSize, wSizeStep * cSize};
auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
You may want to update that (or bail out early if `flatten=true`).
That's a great catch, thanks! This should be captured by one of the pre-conditions and that's what I'll do.
We don't have a mechanism within TD to trigger this case, which means it's hard to test it. And I'm not sure whether extending TD just to test things is the right thing 🤔 I need to think of a use-case other than testing 😅
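A minimal sketch of such a pre-condition (illustrative, not the exact code that landed; a later commit in this PR does add "Make sure 'flattening' + 'masked vectorisation' are not allowed"):

```cpp
// Hypothetical guard in the depthwise-conv vectorisation path: the
// flattening rewrite relies on static shape_casts, so reject it when
// masking (i.e. a scalable/dynamic channel dim) is in play.
if (flatten && scalableChDim)
  return rewriter.notifyMatchFailure(
      op, "cannot flatten a masked/scalable depthwise conv");
```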
This patch adds support for scalable vectorisation of depthwise 1D HWC convolutions, `linalg.depthwise_conv_1d_nwc_wc`. This is implemented by adding support for masking. Two major assumptions are made:
* only the channel dimension can be scalable/dynamic (i.e. the trailing dim),
* when specifying vector sizes to use in the vectoriser, only the size corresponding to the channel dim is effectively used (other dims are inferred from the context).

In terms of scalable vectorisation, this should be sufficient to cover all practical cases (i.e. making arbitrary dims scalable wouldn't make much sense). As for more generic cases with dynamic shapes (e.g. w or n dims being dynamic), more work would be needed. In particular, one would have to consider the filter and input/output tensors separately. However, it's not clear whether that would be of any use in practice.
Addressing Cullen's comments
…utions Addressing PR comments:
- add CSE in tests, update check-lines accordingly
- add support for plain (non-scalable) masked vectorisation
- moved pre-conditions for vectorisation to a dedicated hook
…onvolutions Address comments from Ben and Crefeda
…ions Better LIT var names in tests
…ions More documentaiton, some simplification (as per Cullen's comments)
…ions Address Diego's comments, move to vector utils
…utions Generalise the code a tiny bit to cover for 1D depthwise NCW convs (once supported by the vectoriser).
…utions
* Add missing dyn dimension in a test
* Make sure "flattening" + "masked vectorisation" are not allowed
LGTM, for enablement purposes. I understand the criticality of this and I don't want to block the development but, as discussed (also offline with Nicolas), we know that this is a short-term specialization and that all the convolution support in the vectorizer needs to be revisited. As such, it may happen that some of this functionality has to be disabled/revisited in the future if it hinders a more generic path forward. I would also be concerned if there are plans to continue extending this path so let's make sure we discuss those and plan for a better approach if that is the case.
Thanks all for reviewing, I will be merging this shortly (it's quite critical for us - we can't leverage wider SVE vectors otherwise).
I would be very keen to participate in any such refactor. Let's continue brainstorming about this!
I am more than happy to redesign/refactor this if need be 👍🏻
Follow-up for #81625. Relands #85225 with a minor update to the RUN line to fix buildbot failures:

```diff
-// RUN: %{compile} | %{run} | FileCheck %s
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
```

Failing buildbots after landing #85225:
* https://lab.llvm.org/buildbot/#/builders/184/builds/11363
* https://lab.llvm.org/buildbot/#/builders/176/builds/9331
…d22716224 Local branch amd-gfx 02fd227 Merged main:4f873730d6ac1a8496cdef939cc451f178a864ee into amd-gfx:f078619cd03f Remote branch main c56bd7a [mlir][linalg] Enable masked vectorisation for depthwise convolutions (llvm#81625)
This patch adds support for masked vectorisation of depthwise 1D WC convolutions, `linalg.depthwise_conv_1d_nwc_wc`. This is implemented by adding support for masking.
Two major assumptions are made:
* only the channel dimension can be scalable/dynamic (i.e. the trailing dim),
* when specifying vector sizes to use in the vectoriser, only the size corresponding to the channel dim is effectively used (other dims are inferred from the context).
In terms of scalable vectorisation, this should be sufficient to cover all practical cases (i.e. making arbitrary dims scalable wouldn't make much sense). As for more generic cases with dynamic shapes (e.g. W or N dims being dynamic), more work would be needed. In particular, one would have to consider the filter and input/output tensors separately.