Skip to content

[mlir][linalg] Add NHWC + FHWC Img2Col #68708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 13, 2023

Conversation

FranklandJack
Copy link
Contributor

Adds the Img2Col transformation for the fhwc channel ordering in a Conv2D. Because of how the channel ordering affects the matrix dimensions in the flattened filter this results in a slightly different implementation of the actual "matrix multiplication". Instead of doing a regular row-column dot-product this arrangement requires a row-row dot product, otherwise the filter matrix would first need to be transposed.

Adds a lit test to the transform dialect to check the semantics of the optimization are correct.

@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2023

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

Changes

Adds the Img2Col transformation for the fhwc channel ordering in a Conv2D. Because of how the channel ordering affects the matrix dimensions in the flattened filter this results in a slightly different implementation of the actual "matrix multiplication". Instead of doing a regular row-column dot-product this arrangement requires a row-row dot product, otherwise the filter matrix would first need to be transposed.

Adds a lit test to the transform dialect to check the semantics of the optimization are correct.


Full diff: https://github.com/llvm/llvm-project/pull/68708.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+8)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+147)
  • (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+70)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..3597209d7f90c25 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1175,6 +1175,14 @@ FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 FailureOr<std::pair<Operation *, Operation *>>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
 
+/// Same as the above but for Fhwc channel orderings in the filter. In this case
+/// the matrix multiplication is actually a row-wise dot-product rather than a
+/// row-column dot-product. This is to avoid transposing the filter matrix which
+/// would be required for a regular matrix multiplication to produce the correct
+/// output dimensions.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
+
 /// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
 /// reduction among the input channels so each convolution can be a
 /// matrix-vector product and by transposing both input filter so channels are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9ce780d3d249cfb..8508507871d0c6c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3118,6 +3118,9 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
           .Case([&](linalg::Conv2DNhwcHwcfOp op) {
             return rewriteInIm2Col(rewriter, op);
           })
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
           .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
             return rewriteInIm2Col(rewriter, op);
           })
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 275e78aaa73dde6..848a7b5e7fc52c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -494,6 +494,140 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
                         reshapedResult.getOperation());
 }
 
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[1];
+  int64_t fw = filterShape[2];
+  int64_t ic = filterShape[3];
+
+  Location loc = convOp.getLoc();
+
+  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  // Convert the input to a (BMK) column tensor.
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+      });
+
+  // Because we didn't transpose the filters we don't actually have a batched
+  // matrix multiply. Instead, we have an operation consisting of "rowise" dot
+  // products.
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul =
+            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
+        Value add = createAdd(loc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
+
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
 namespace {
 
 class ConvertConv2DNhwcHwcf final
@@ -534,6 +668,19 @@ class ConvertConv2DNchwFchw final
     return success();
   }
 };
+
+class ConvertConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
 } // end anonymous namespace
 
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 657cf83f25460fd..b2470ed7b748042 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -279,6 +279,76 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
+
+// Compute input channel/convolved indices.
+// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
+// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
+// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
+
+// Extract from the input tensor.
+// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
+// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
+// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: @conv_2d_nhwc_fhwc
+//      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+//      CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
+//      CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
+//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+//           CHECK-SAME: #[[MAP0]]
+//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+//                CHECK: linalg.yield %{{.+}} : f32
+//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+//           CHECK-SAME: #[[MAP1]]
+//           CHECK-SAME: #[[MAP2]]
+//           CHECK-SAME: #[[MAP3]]
+//           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
+//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
+//                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+//                CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//                CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+//                CHECK:     linalg.yield %[[ADD]] : f32
+//                CHECK: } -> tensor<1x196x16xf32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: return %[[RESULT]]
+
+func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_fhwc
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+  transform.print %transformed {name = "transformed"}: !transform.any_op
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Looks good, just a quick addition to the set of Img2Col patterns.

Adds the Img2Col transformation for the fhwc channel ordering in a
Conv2D. Because of how the channel ordering affects the matrix
dimensions in the flattened filter this results in a slightly different
implementation of the actual "matrix multiplication". Instead of doing a
regular row-column dot-product this arrangement requires a row-row dot
product, otherwise the filter matrix would first need to be transposed.

Adds a lit test to the transform dialect to check the semantics of the
optimization are correct.

Signed-off-by: Jack Frankland <[email protected]>
@FranklandJack FranklandJack force-pushed the jacfra01/img2col-nhwc+fhwc branch from 8bc2dce to aa22d0f Compare October 12, 2023 10:32
Copy link
Contributor Author

@FranklandJack FranklandJack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review!

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@c-rhodes c-rhodes merged commit 92e751d into llvm:main Oct 13, 2023
@c-rhodes
Copy link
Collaborator

Landed on behalf of @FranklandJack as he doesn't have commit access.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants