Skip to content

[mlir][linalg] Enable masked vectorisation for depthwise convolutions #81625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 14, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Feb 13, 2024

This patch adds support for masked vectorisation of depthwise 1D WC
convolutions,linalg.depthwise_conv_1d_nwc_wc. This is implemented by
adding support for masking.

Two major assumptions are made:

  • only the channel dimension can be dynamic/scalable (i.e. the
    trailing dim),
  • when specifying vector sizes to use in the vectoriser, only the size
    corresponding to the channel dim is effectively used (other dims are
    inferred from the context).

In terms of scalable vectorisation, this should be sufficient to cover
all practical cases (i.e. making arbitrary dim scalable wouldn't make
much sense). As for more generic cases with dynamic shapes (e.g. W or N
dims being dynamic), more work would be needed. In particular, one would
have to consider the filter and input/output tensors separately.

@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This patch adds support for scalable vectorisation of depthwise 1D HWC
convolutions,linalg.depthwise_conv_1d_nwc_wc. This is implemented by
adding support for masking.

Two major assumptions are made:

  • only the channel dimension can be scalable/dynamic (i.e. the
    trailing dim),
  • when specifying vector sizes to use in the vectoriser, only the size
    corresponding to the channel dim is effectively used (other dims are
    inferred from the context).

In terms of scalable vectorisation, this should be sufficient that cover
all practical cases (i.e. making arbitrary dim scalable wouldn't make
much sense). As for more generic cases with dynamic shapes (e.g. w or n
dims being dynamic), more work would be needed. In particular, one would
have to consider the filter and input/ouput tensors separately. However,
it's not clear whether that would be of any use in practice.


Patch is 23.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81625.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+112-19)
  • (added) mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir (+161)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..aed96be5b6da00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -54,6 +54,7 @@ using namespace mlir::linalg;
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     ArrayRef<int64_t> inputVecSizes = {},
                      bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1609,6 +1610,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  // Support dynamic shapes in 1D depthwise convolution, but only in the
+  // _channel_ dimension. That's exclusively to support scalable vectorisation.
+  if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    auto lhsShaped = op.getDpsInputOperand(0)->get();
+    ArrayRef<int64_t> lhsShape =
+        dyn_cast<ShapedType>(lhsShaped.getType()).getShape();
+    auto shapeWithoutCh = lhsShape.drop_back(1);
+    if (ShapedType::isDynamicShape(shapeWithoutCh))
+      return failure();
+
+    return success();
+  }
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -1789,7 +1803,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Only element-wise ops supported in the presence of scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && isElementwise(linalgOp));
+  return success(linalgOp && (isElementwise(linalgOp) ||
+                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -1871,7 +1886,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // features. Will require stride/dilation attributes inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, flatten1DDepthwiseConv);
+                  rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -2697,6 +2712,7 @@ struct Conv1DGenerator
         return;
       break;
     }
+    hasTensorSemantics = linalgOp.hasPureTensorSemantics();
     // The op is now known to be valid.
     valid = true;
   }
@@ -3000,13 +3016,21 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv(bool flatten) {
+  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+                                       bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
+    bool scalableChDim = false;
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
+    // Dynamic channel size implies scalable vectorisation
+    if (ShapedType::isDynamic(cSize)) {
+      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
+      cSize = channelDimVecSize;
+      scalableChDim = true;
+    }
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3027,20 +3051,74 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType);
-    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+        lhsEltType, {false, false, scalableChDim});
+    VectorType rhsType =
+        VectorType::get({kwSize, cSize}, rhsEltType,
+                        /*scalableDims=*/{false, scalableChDim});
+    VectorType resType =
+        VectorType::get({nSize, wSize, cSize}, resEltType,
+                        /*scalableDims=*/{false, false, scalableChDim});
+
+    // Masks the input xfer Op along the channel dim, iff the corresponding
+    // scalable flag is set.
+    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+                               ArrayRef<bool> scalableDims,
+                               Operation *opToMask) {
+      bool scalableChDim = scalableDims.back();
+      if (!scalableChDim)
+        return opToMask;
+
+      auto maskType =
+          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+
+      SmallVector<OpFoldResult> mixedSourceDims =
+          hasTensorSemantics
+              ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    })
+              : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    });
+
+      Value maskOp =
+          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+
+      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
+    };
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedLHS = maybeMaskXferOp(
+        lhsType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
+
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
+    auto maybeMaskedRHS = maybeMaskXferOp(
+        rhsType.getShape(),
+        /*scalableDims=*/{false, scalableChDim}, rhs.getDefiningOp());
+
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedRES = maybeMaskXferOp(
+        resType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, res.getDefiningOp());
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -3055,7 +3133,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
+            loc, maybeMaskedLHS->getResult(0),
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             inOutSliceSizes, inOutStrides));
       }
@@ -3063,12 +3141,13 @@ struct Conv1DGenerator
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+          loc, maybeMaskedRHS->getResult(0),
+          /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
+          loc, maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
@@ -3107,6 +3186,7 @@ struct Conv1DGenerator
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      // TODO: Replace with maybeMasked
       for (auto &collection :
            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
@@ -3117,8 +3197,8 @@ struct Conv1DGenerator
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
+      maybeMaskedRES = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
     }
@@ -3127,10 +3207,12 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
-    return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
-        .getOperation();
+    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
+        loc, maybeMaskedRES->getResult(0), resShaped,
+        ValueRange{zero, zero, zero});
+    return maybeMaskXferOp(resType.getShape(),
+                           /*scalableDims=*/{false, false, scalableChDim},
+                           resOut);
   }
 
   /// Lower:
@@ -3171,8 +3253,9 @@ struct Conv1DGenerator
     if (!lhs || !rhs)
       return nullptr;
 
-    if (isa<FloatType>(resTy.getElementType()))
+    if (isa<FloatType>(resTy.getElementType())) {
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+    }
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
     return rewriter.create<arith::AddIOp>(loc, mul, res);
@@ -3268,7 +3351,8 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
+  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+                                             bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3279,7 +3363,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv(flatten);
+      return depthwiseConv(vecChDimSize, flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3291,6 +3375,7 @@ struct Conv1DGenerator
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
+  bool hasTensorSemantics = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
@@ -3346,6 +3431,7 @@ struct Conv1DGenerator
 // TODO: extend the generic vectorization to support windows and drop this.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     ArrayRef<int64_t> inputVecSizes,
                      bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
@@ -3371,7 +3457,14 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv(flatten1DDepthwiseConv);
+
+  uint64_t vecChDimSize = ShapedType::kDynamic;
+  if (!inputVecSizes.empty()) {
+    // Only use the input vector size corresponding to the channel dim. Other
+    // vector dims will be inferred from the Ops.
+    vecChDimSize = inputVecSizes[2];
+  }
+  return e.generateDilatedConv(vecChDimSize, flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
new file mode 100644
index 00000000000000..d4b3574451c2bf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
+                                                   %filter: tensor<1x?xi8>,
+                                                   %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK-DAG:       arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[FILTER]], %[[VAL_5]] : tensor<1x?xi8>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C0_I8:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_INPUT:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_SIZE_INPUT]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[C0_I8_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[CH_DIM_SIZE_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1_1]], %[[CH_DIM_SIZE_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[C0_I8_1]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_23]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_25:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[CH_DIM_SIZE_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[VAL_22]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[VEC_IN_0]], %[[FLT_B]] : vector<1x8x[4]xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[VEC_OUT_0]] : vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_36:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_37:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_36]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_39:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_38]], %[[VAL_39]], %[[VAL_37]] : vector<1x8x[4]xi1>
+
+/// Write the output tensor
+// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x?xf32>,
+                                                                %filter: memref<2x?xf32>,
+                                                                %output: memref<3x2x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
+    outs(%output : memref<3x2x?xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
+
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[FILTER]], %[[VAL_5]] : memref<2x?xf32>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+
+/// Create a mask for the input tensor
+// CHECK:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_11:.*]] = memref.dim %[[INPUT]], %[[VAL_10]] : memref<3x5x?xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] : vector<3x4x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

This patch adds support for scalable vectorisation of depthwise 1D HWC
convolutions,linalg.depthwise_conv_1d_nwc_wc. This is implemented by
adding support for masking.

Two major assumptions are made:

  • only the channel dimension can be scalable/dynamic (i.e. the
    trailing dim),
  • when specifying vector sizes to use in the vectoriser, only the size
    corresponding to the channel dim is effectively used (other dims are
    inferred from the context).

In terms of scalable vectorisation, this should be sufficient that cover
all practical cases (i.e. making arbitrary dim scalable wouldn't make
much sense). As for more generic cases with dynamic shapes (e.g. w or n
dims being dynamic), more work would be needed. In particular, one would
have to consider the filter and input/ouput tensors separately. However,
it's not clear whether that would be of any use in practice.


Patch is 23.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81625.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+112-19)
  • (added) mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir (+161)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..aed96be5b6da00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -54,6 +54,7 @@ using namespace mlir::linalg;
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     ArrayRef<int64_t> inputVecSizes = {},
                      bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1609,6 +1610,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  // Support dynamic shapes in 1D depthwise convolution, but only in the
+  // _channel_ dimension. That's exclusively to support scalable vectorisation.
+  if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    auto lhsShaped = op.getDpsInputOperand(0)->get();
+    ArrayRef<int64_t> lhsShape =
+        dyn_cast<ShapedType>(lhsShaped.getType()).getShape();
+    auto shapeWithoutCh = lhsShape.drop_back(1);
+    if (ShapedType::isDynamicShape(shapeWithoutCh))
+      return failure();
+
+    return success();
+  }
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -1789,7 +1803,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Only element-wise ops supported in the presence of scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && isElementwise(linalgOp));
+  return success(linalgOp && (isElementwise(linalgOp) ||
+                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -1871,7 +1886,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // features. Will require stride/dilation attributes inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, flatten1DDepthwiseConv);
+                  rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -2697,6 +2712,7 @@ struct Conv1DGenerator
         return;
       break;
     }
+    hasTensorSemantics = linalgOp.hasPureTensorSemantics();
     // The op is now known to be valid.
     valid = true;
   }
@@ -3000,13 +3016,21 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv(bool flatten) {
+  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+                                       bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
+    bool scalableChDim = false;
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
+    // Dynamic channel size implies scalable vectorisation
+    if (ShapedType::isDynamic(cSize)) {
+      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
+      cSize = channelDimVecSize;
+      scalableChDim = true;
+    }
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3027,20 +3051,74 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType);
-    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+        lhsEltType, {false, false, scalableChDim});
+    VectorType rhsType =
+        VectorType::get({kwSize, cSize}, rhsEltType,
+                        /*scalableDims=*/{false, scalableChDim});
+    VectorType resType =
+        VectorType::get({nSize, wSize, cSize}, resEltType,
+                        /*scalableDims=*/{false, false, scalableChDim});
+
+    // Masks the input xfer Op along the channel dim, iff the corresponding
+    // scalable flag is set.
+    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+                               ArrayRef<bool> scalableDims,
+                               Operation *opToMask) {
+      bool scalableChDim = scalableDims.back();
+      if (!scalableChDim)
+        return opToMask;
+
+      auto maskType =
+          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+
+      SmallVector<OpFoldResult> mixedSourceDims =
+          hasTensorSemantics
+              ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    })
+              : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    });
+
+      Value maskOp =
+          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+
+      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
+    };
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedLHS = maybeMaskXferOp(
+        lhsType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
+
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
+    auto maybeMaskedRHS = maybeMaskXferOp(
+        rhsType.getShape(),
+        /*scalableDims=*/{false, scalableChDim}, rhs.getDefiningOp());
+
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedRES = maybeMaskXferOp(
+        resType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, res.getDefiningOp());
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -3055,7 +3133,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
+            loc, maybeMaskedLHS->getResult(0),
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             inOutSliceSizes, inOutStrides));
       }
@@ -3063,12 +3141,13 @@ struct Conv1DGenerator
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+          loc, maybeMaskedRHS->getResult(0),
+          /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
+          loc, maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
@@ -3107,6 +3186,7 @@ struct Conv1DGenerator
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      // TODO: Replace with maybeMasked
       for (auto &collection :
            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
@@ -3117,8 +3197,8 @@ struct Conv1DGenerator
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
+      maybeMaskedRES = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
     }
@@ -3127,10 +3207,12 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
-    return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
-        .getOperation();
+    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
+        loc, maybeMaskedRES->getResult(0), resShaped,
+        ValueRange{zero, zero, zero});
+    return maybeMaskXferOp(resType.getShape(),
+                           /*scalableDims=*/{false, false, scalableChDim},
+                           resOut);
   }
 
   /// Lower:
@@ -3171,8 +3253,9 @@ struct Conv1DGenerator
     if (!lhs || !rhs)
       return nullptr;
 
-    if (isa<FloatType>(resTy.getElementType()))
+    if (isa<FloatType>(resTy.getElementType())) {
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+    }
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
     return rewriter.create<arith::AddIOp>(loc, mul, res);
@@ -3268,7 +3351,8 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
+  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+                                             bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3279,7 +3363,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv(flatten);
+      return depthwiseConv(vecChDimSize, flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3291,6 +3375,7 @@ struct Conv1DGenerator
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
+  bool hasTensorSemantics = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
@@ -3346,6 +3431,7 @@ struct Conv1DGenerator
 // TODO: extend the generic vectorization to support windows and drop this.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     ArrayRef<int64_t> inputVecSizes,
                      bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
@@ -3371,7 +3457,14 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv(flatten1DDepthwiseConv);
+
+  uint64_t vecChDimSize = ShapedType::kDynamic;
+  if (!inputVecSizes.empty()) {
+    // Only use the input vector size corresponding to the channel dim. Other
+    // vector dims will be inferred from the Ops.
+    vecChDimSize = inputVecSizes[2];
+  }
+  return e.generateDilatedConv(vecChDimSize, flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
new file mode 100644
index 00000000000000..d4b3574451c2bf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
+                                                   %filter: tensor<1x?xi8>,
+                                                   %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK-DAG:       arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[FILTER]], %[[VAL_5]] : tensor<1x?xi8>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C0_I8:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_INPUT:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_SIZE_INPUT]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[C0_I8_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[CH_DIM_SIZE_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1_1]], %[[CH_DIM_SIZE_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[C0_I8_1]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_23]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_25:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[CH_DIM_SIZE_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[VAL_22]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[VEC_IN_0]], %[[FLT_B]] : vector<1x8x[4]xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[VEC_OUT_0]] : vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_36:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_37:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_36]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_39:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_38]], %[[VAL_39]], %[[VAL_37]] : vector<1x8x[4]xi1>
+
+/// Write the output tensor
+// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x?xf32>,
+                                                                %filter: memref<2x?xf32>,
+                                                                %output: memref<3x2x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
+    outs(%output : memref<3x2x?xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
+
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[FILTER]], %[[VAL_5]] : memref<2x?xf32>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+
+/// Create a mask for the input tensor
+// CHECK:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_11:.*]] = memref.dim %[[INPUT]], %[[VAL_10]] : memref<3x5x?xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] : vector<3x4x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[...
[truncated]

@banach-space banach-space requested a review from MacDue February 13, 2024 16:35
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Taking a first look!

Comment on lines 3026 to 3194
bool scalableChDim = false;
int64_t nSize, wSize, cSize, kwSize;
// kernel{kw, c}
bindShapeDims(rhsShapedType, kwSize, cSize);
// Dynamic channel size implies scalable vectorisation
if (ShapedType::isDynamic(cSize)) {
assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
cSize = channelDimVecSize;
scalableChDim = true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why a dynamic channel dimension implies scalable vectors? We have to make sure this also works for non scalable cases.

Copy link
Member

@MacDue MacDue Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder this too (though I'm not very familiar with the linalg vectorizer). Could you also use scalable vectors if your channel dimension was large (say 100 elements), but still a static size?

Would it be possible to control this with the by passing a the scalable dims flags for the vector sizes, like is the case with matmuls?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why a dynamic channel dimension implies scalable vectors?

Because I run out of time this week an that's the case I need the most :) I will send an update shortly to add support for more generic case too. It's actually quite easy. The most labor-intensive task is writing tests for these convs.

Could you also use scalable vectors if your channel dimension was large (say 100 elements), but still a static size?

Yes, we already do that for elementwise Ops. However, you need to tile first - that will lead to tensors with dynamic shapes.

Copy link
Contributor Author

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reviews, sending updates shortly!

Comment on lines 3026 to 3194
bool scalableChDim = false;
int64_t nSize, wSize, cSize, kwSize;
// kernel{kw, c}
bindShapeDims(rhsShapedType, kwSize, cSize);
// Dynamic channel size implies scalable vectorisation
if (ShapedType::isDynamic(cSize)) {
assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
cSize = channelDimVecSize;
scalableChDim = true;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why a dynamic channel dimension implies scalable vectors?

Because I run out of time this week an that's the case I need the most :) I will send an update shortly to add support for more generic case too. It's actually quite easy. The most labor-intensive task is writing tests for these convs.

Could you also use scalable vectors if your channel dimension was large (say 100 elements), but still a static size?

Yes, we already do that for elementwise Ops. However, you need to tile first - that will lead to tensors with dynamic shapes.

@banach-space banach-space changed the title [mlir][linalg] Add scalable vectorisation for depthwise convolutions [mlir][linalg] Enable masked vectorisation for depthwise convolutions Feb 16, 2024
%filter: memref<2x?xf32>,
%output: memref<3x2x?xf32>) {
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the significance of a dilation of 2? Is dilation explained anywhere? I can't see anything in the linalg docs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been struggling to find good documentation for this. This book provides a nice description:

See Figure 10.3.

The reason for adding this case here is to demonstrate that this patch is not doing anything that would only work for the default case (dilation = 1). But I appreciate that it's a bit tricky to follow without knowing the semantics of the Op itself. Is there anything I can do to make it clearer? You could try comparing against:

func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
%filter: memref<2x4xf32>,
%output: memref<3x2x4xf32>) {
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
outs(%output : memref<3x2x4xf32>)
return
}
// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
/// Read the whole data in one shot.
// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> from vector<2x4xf32>
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> from vector<2x4xf32>
/// w == 0, kw = 0
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<8xf32> to vector<3x8xf32>
// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
/// w == 0, kw = 1
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[SH_FILTER_1]] : vector<8xf32> to vector<3x8xf32>
// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
// Write the result back in one shot.
// CHECK: %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
// CHECK: vector.transfer_write %[[SC_FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

But note that I wasn't trying to keep the check-lines consistent and, more importantly, that is testing "flattened" convs.

@banach-space
Copy link
Contributor Author

Kind ping :)

I believe that I've addressed all your comments - could you take another look? This is quite critical for me 🙏🏻 😅

@banach-space banach-space force-pushed the andrzej/scalable_depthwise_conv branch from 5e7090a to 625eef5 Compare March 11, 2024 18:33
Copy link

github-actions bot commented Mar 11, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space banach-space force-pushed the andrzej/scalable_depthwise_conv branch from 625eef5 to 30d407a Compare March 11, 2024 19:17
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left a few final comments but otherwise this LGTM, cheers. I would appreciate an integration test as well, but that doesn't need to be in this patch.

// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>

/// Convolution
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the insert/extract strided slice operations in this test and the scalable equivalent are not necessary, with canonicalization these are removed which I think makes it easier to follow, as well as making it more obvious what is happening for the final test where the dilation is 2 and the extract_strided_slice is necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the extra indirection is somewhat confusing, but I deliberately avoid canonicalization in these tests - otherwise we wouldn't be testing the vectoriser. In fact, we have introduced dedicated TD ops to separate:

In particular, once canonicalizations are included, the generated output will depend on many other parts of MLIR and I want to avoid that. Having said that, we could add more tests with patterns.

}
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
loc, res,
loc, maybeMaskedRes->getResult(0),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic below this (for when flatten=true) looks like it will drop the scalable dims:

    auto inOutFlattenSliceSizes =
        SmallVector<int64_t>{nSize, wSizeStep * cSize};
    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}

You may want to update that (or bail out if early if flatten=true).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a great catch, thanks! This should be captured by one of the pre-conditions and that's what I'll do.

We don't have a mechanism within TD to trigger this case, which means it's hard to test it. And I'm not sure whether extending TD just to test things is the right thing 🤔 I need to think of a use-case other than testing 😅

This patch adds support for scalable vectorisation of depthwise 1D HWC
convolutions,`linalg.depthwise_conv_1d_nwc_wc`. This is implemented by
adding support for masking.

Two major assumptions are made:
  * only the channel dimension can be scalable/dynamic (i.e. the
    trailing dim),
  * when specifying vector sizes to use in the vectoriser, only the size
    corresponding to the channel dim is effectively used (other dims are
    inferred from the context).

In terms of scalable vectorisation, this should be sufficient that cover
all practical cases (i.e. making arbitrary dim scalable wouldn't make
much sense). As for more generic cases with dynamic shapes (e.g. w or n
dims being dynamic), more work would be needed. In particular, one would
have to consider the filter and input/ouput tensors separately. However,
it's not clear whether that would be of any use in practice.
…utions

Addressing PR comments:
- add CSE in tests, update check-lines accordingly
- add support for plain (non-scalable) masked vectorisation
- moved pre-conditions for vectorisation to a dedicated hook
…onvolutions

Address comments from Ben and Crefeda
…ions

More documentaiton, some simplification (as per Cullen's comments)
…ions

Address Diego's comments, move  to vector utils
…utions

Generalise the code a tiny bit to cover for 1D depthwise NCW convs (once
supported by the vectoriser).
@banach-space banach-space force-pushed the andrzej/scalable_depthwise_conv branch from 30d407a to 2a8ce8a Compare March 14, 2024 10:35
…utions

* Add missing dyn dimension in a test
* Make sure "flattening" + "masked vectorisation" are not allowed
@banach-space banach-space force-pushed the andrzej/scalable_depthwise_conv branch from 2a8ce8a to b367636 Compare March 14, 2024 13:39
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 14, 2024
@banach-space
Copy link
Contributor Author

I would appreciate an integration test as well, but that doesn't need to be in this patch.

#85225

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, for enablement purposes. I understand the criticality of this and I don't want to block the development but, as discussed (also offline with Nicolas), we know that this is a short-term specialization and that all the convolution support in the vectorizer needs to be revisited. As such, it may happen that some of this functionality has to be disabled/revisited in the future if it hinders a more generic path forward. I would also be concerned if there are plans to continue extending this path so let's make sure we discuss those and plan for a better approach if that is the case.

@banach-space
Copy link
Contributor Author

Thanks all for reviewing, I will be merging this shortly (it's quite critical for us - we can't leverage wider SVE vectors otherwise).

(...) we know that this is a short-term specialization and that all the convolution support in the vectorizer needs to be revisited.

I would be very keen to participate in any such refactor. Lets continue brainstorming about this!

As such, it may happen that some of this functionality has to be disabled/revisited in the future if it hinders a more generic path forward.

I am more than happy to redesign/refactor this if need be 👍🏻

@banach-space banach-space merged commit c56bd7a into llvm:main Mar 14, 2024
@banach-space banach-space deleted the andrzej/scalable_depthwise_conv branch March 16, 2024 16:35
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 16, 2024
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 22, 2024
banach-space added a commit that referenced this pull request Mar 22, 2024
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
banach-space added a commit that referenced this pull request Mar 25, 2024
Follow-up for #81625

Relands #85225 with a minor update to the RUN line to fix buildbot
failures:
```diff
-// RUN: %{compile} | %{run} | FileCheck %s
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
```

Failing buildbots after landing #85225:
 * https://lab.llvm.org/buildbot/#/builders/184/builds/11363
 * https://lab.llvm.org/buildbot/#/builders/176/builds/9331
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request May 15, 2024
…d22716224

Local branch amd-gfx 02fd227 Merged main:4f873730d6ac1a8496cdef939cc451f178a864ee into amd-gfx:f078619cd03f
Remote branch main c56bd7a [mlir][linalg] Enable masked vectorisation for depthwise convolutions (llvm#81625)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants