Skip to content

[mlir][linalg] Add NHWC + FHWC Img2Col #68708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,14 @@ FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);

/// Same as the above but for Fhwc channel orderings in the filter. In this case
/// the matrix multiplication is actually a row-wise dot-product rather than a
/// row-column dot-product. This is to avoid transposing the filter matrix which
/// would be required for a regular matrix multiplication to produce the correct
/// output dimensions.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);

/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
/// reduction among the input channels so each convolution can be a
/// matrix-vector product and by transposing both input filter so channels are
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3118,6 +3118,9 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
.Case([&](linalg::Conv2DNhwcHwcfOp op) {
return rewriteInIm2Col(rewriter, op);
})
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
return rewriteInIm2Col(rewriter, op);
})
.Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
return rewriteInIm2Col(rewriter, op);
})
Expand Down
150 changes: 149 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,141 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
reshapedResult.getOperation());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");

if (!inputType.hasStaticShape())
return rewriter.notifyMatchFailure(convOp,
"expected a static shape for the input");

// TODO: Support dilation.
if (!hasAllOneValues(convOp.getDilations()))
return rewriter.notifyMatchFailure(convOp,
"expected all ones for dilations");

MLIRContext *context = rewriter.getContext();
Value input = convOp.getInputs()[0];
Value filter = convOp.getInputs()[1];
Value output = convOp.getOutputs()[0];

ArrayRef<int64_t> filterShape = filterType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();

int64_t n = outputShape[0];
int64_t oh = outputShape[1];
int64_t ow = outputShape[2];
int64_t oc = outputShape[3];
int64_t fh = filterShape[1];
int64_t fw = filterShape[2];
int64_t ic = filterShape[3];

Location loc = convOp.getLoc();

// Reshape output and filter to the LHS and result of a "row-wise" matrix
// multiplication.
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedFilterType, filter, filterReassocIndices);

SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
RankedTensorType reshapedOutputType =
RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedOutputType, output, outputReassocIndices);

SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
Value colTensor = rewriter.create<tensor::EmptyOp>(
loc, colTensorShape, inputType.getElementType());

// Convert the input to a (BMK) column tensor.
auto nloops = colTensorShape.size();

auto parallel = utils::IteratorType::parallel;
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

SmallVector<AffineMap> img2colIndexingMaps = {
AffineMap::getMultiDimIdentityMap(nloops, context)};

auto img2ColTensor = rewriter.create<linalg::GenericOp>(
loc, colTensor.getType(),
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
// Get the iterators named based on the matmul (batch, m, k).
Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value> mIndices = unrollIndex(
nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = mIndices[0];
auto owIndex = mIndices[1];

SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
auto fhIndex = kIndices[0];
auto fwIndex = kIndices[1];
auto icIndex = kIndices[2];

// Extract the input element corresponding to the expanded indices.
Value hIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
convOp.getStrides().getValues<int64_t>()[0]);
Value wIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
convOp.getStrides().getValues<int64_t>()[1]);

// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
loc, input, extractionIndices);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
});

// Because we didn't transpose the filters we don't actually have a batched
// matrix multiply. Instead, we have an operation consisting of "row-wise" dot
// products.
AffineExpr bDim, mDim, nDim, kDim;
bindDims(context, bDim, mDim, nDim, kDim);
auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
parallel, reduction};

auto genericOp = rewriter.create<linalg::GenericOp>(
loc, reshapedOutputType,
/*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
/*outputs=*/ValueRange{reshapedOutput},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
Value mul =
createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
Value add = createAdd(loc, mul, args[2], nestedBuilder);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
});
Value result = genericOp.getResults().front();

auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
loc, outputType, result, outputReassocIndices);

rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

return std::make_pair(img2ColTensor.getOperation(),
reshapedResult.getOperation());
}

namespace {

class ConvertConv2DNhwcHwcf final
Expand Down Expand Up @@ -534,12 +669,25 @@ class ConvertConv2DNchwFchw final
return success();
}
};

class ConvertConv2DNhwcFhwc final
: public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
PatternRewriter &rewriter) const override {
if (failed(rewriteInIm2Col(rewriter, convOp)))
return failure();
return success();
}
};
} // end anonymous namespace

void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
ConvertConv2DNchwFchw>(context);
ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
} // end namespace linalg
} // end namespace mlir
70 changes: 70 additions & 0 deletions mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,76 @@ transform.sequence failures(propagate) {

// -----

// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)

// Collapsed indices.
// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index

// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])

// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32

// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_2d_nhwc_fhwc
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: #[[MAP0]]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<1x196x16xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: return %[[RESULT]]

func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_fhwc
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
transform.print %transformed {name = "transformed"}: !transform.any_op
}

// -----

// Check for signed extend when the input type is smaller than the accumulator type.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
Expand Down