|
| 1 | +//===- TransposeConv2D.cpp - Convolution transposition -------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 10 | +#include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 11 | +#include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 12 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 13 | +#include "mlir/IR/BuiltinTypes.h" |
| 14 | +#include "mlir/IR/PatternMatch.h" |
| 15 | +#include "mlir/IR/ValueRange.h" |
| 16 | +#include "mlir/Support/LogicalResult.h" |
| 17 | +#include "mlir/Transforms/DialectConversion.h" |
| 18 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 19 | +#include "llvm/ADT/SmallVector.h" |
| 20 | +#include "llvm/Support/ErrorHandling.h" |
| 21 | +#include "llvm/Support/RWMutex.h" |
| 22 | +#include <memory> |
| 23 | +#include <numeric> |
| 24 | + |
| 25 | +namespace mlir { |
| 26 | +namespace linalg { |
| 27 | +namespace { |
// clang-format off
/// Convolution converter that applies the following rewrite:
///
/// Before:
///
///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
///                                  strides = dense<2> : tensor<2xi64>}
///      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
///      outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
///
/// After:
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %0 = tensor.empty() : tensor<2x2x6x8xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
///   %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
///                 permutation = [1, 2, 3, 0]
///   %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
///     ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
///     -> tensor<1x2x2x8xf32>
///
/// with an analogous example for the quantized case.
// clang-format on
| 51 | +template <typename FHWCConvOp, typename HWCFConvOp> |
| 52 | +FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, |
| 53 | + FHWCConvOp op) { |
| 54 | + // Construct a permutation of the filter tensor dimensions. For a 2D |
| 55 | + // convolution this will be known statically as [1, 2, 3, 0]. |
| 56 | + SmallVector<int64_t> filterPerm({1, 2, 3, 0}); |
| 57 | + |
| 58 | + // Create the type for the transposed filter tensor. |
| 59 | + auto filter = op->getOperand(1); |
| 60 | + auto filterTy = cast<ShapedType>(filter.getType()); |
| 61 | + SmallVector<int64_t> newFilterShape(filterPerm.size()); |
| 62 | + std::generate(std::begin(newFilterShape), std::end(newFilterShape), |
| 63 | + [dim = 0, &filterTy, &filterPerm]() mutable { |
| 64 | + return filterTy.getShape()[filterPerm[dim++]]; |
| 65 | + }); |
| 66 | + |
| 67 | + // Because linalg.transpose expects an "out" parameter we need to pass it a |
| 68 | + // tensor of zeros of the result type so here we construct that tensor. |
| 69 | + auto inputType = op->getOperand(0).getType(); |
| 70 | + auto elementTy = cast<ShapedType>(inputType).getElementType(); |
| 71 | + auto loc = op->getLoc(); |
| 72 | + |
| 73 | + const auto isTensorOp = isa<TensorType>(inputType); |
| 74 | + Value input; |
| 75 | + if (isTensorOp) { |
| 76 | + |
| 77 | + input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy) |
| 78 | + .getResult(); |
| 79 | + } else { |
| 80 | + input = rewriter |
| 81 | + .create<memref::AllocOp>( |
| 82 | + loc, MemRefType::get(newFilterShape, elementTy)) |
| 83 | + .getResult(); |
| 84 | + } |
| 85 | + |
| 86 | + // We can then construct the transposition on our filter. |
| 87 | + auto transpose = |
| 88 | + rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm); |
| 89 | + |
| 90 | + Value newFilter; |
| 91 | + if (isTensorOp) { |
| 92 | + newFilter = transpose.getResult()[0]; |
| 93 | + } else { |
| 94 | + newFilter = input; |
| 95 | + } |
| 96 | + |
| 97 | + SmallVector<Value> newInputs{op.getInputs()}; |
| 98 | + // The filter is always the second input argument, the other inputs can be |
| 99 | + // left as they are. |
| 100 | + newInputs[1] = newFilter; |
| 101 | + // It is possible the convolution doesn't define any results and its |
| 102 | + // out argument is just used instead. |
| 103 | + SmallVector<Type> resultTy; |
| 104 | + if (op.getNumResults()) { |
| 105 | + resultTy.push_back(op->getResult(0).getType()); |
| 106 | + } |
| 107 | + auto newConv = |
| 108 | + rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(), |
| 109 | + op.getStrides(), op.getDilations()); |
| 110 | + rewriter.replaceOp(op, newConv); |
| 111 | + return newConv.getOperation(); |
| 112 | +} |
| 113 | + |
| 114 | +template <typename FHWCConvOp, typename HWCFConvOp> |
| 115 | +class ConvConverter : public OpRewritePattern<FHWCConvOp> { |
| 116 | +public: |
| 117 | + using OpRewritePattern<FHWCConvOp>::OpRewritePattern; |
| 118 | + LogicalResult matchAndRewrite(FHWCConvOp op, |
| 119 | + PatternRewriter &rewriter) const final { |
| 120 | + if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) { |
| 121 | + return failure(); |
| 122 | + } |
| 123 | + return success(); |
| 124 | + } |
| 125 | +}; |
| 126 | +} // namespace |
| 127 | + |
| 128 | +FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, |
| 129 | + linalg::Conv2DNhwcFhwcOp op) { |
| 130 | + |
| 131 | + return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp, |
| 132 | + linalg::Conv2DNhwcHwcfOp>(rewriter, op); |
| 133 | +} |
| 134 | + |
| 135 | +FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, |
| 136 | + linalg::Conv2DNhwcFhwcQOp op) { |
| 137 | + |
| 138 | + return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp, |
| 139 | + linalg::Conv2DNhwcHwcfQOp>(rewriter, op); |
| 140 | +} |
| 141 | + |
| 142 | +void populateTranposeConv2DPatterns(RewritePatternSet &patterns) { |
| 143 | + MLIRContext *context = patterns.getContext(); |
| 144 | + patterns.insert< |
| 145 | + ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>, |
| 146 | + ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>( |
| 147 | + context); |
| 148 | +} |
| 149 | +} // namespace linalg |
| 150 | +} // namespace mlir |
0 commit comments