Skip to content

Commit 4a3d208

Browse files
[mlir][linalg] Add TransposeConv2D Transform Op (#68567)
* Add a LinAlg pass to convert 2D convolutions and quantized 2D convolutions that have the `FHWC` filter channel ordering into a transpose followed by 2D convolutions that have the `HWCF` channel ordering. * Add a lit test to check the semantics of the transformation are correct for both quantized and unquantized variants. Signed-off-by: Jack Frankland <[email protected]>
1 parent 06157a6 commit 4a3d208

File tree

6 files changed

+411
-0
lines changed

6 files changed

+411
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

+49
Original file line numberDiff line numberDiff line change
@@ -2249,6 +2249,55 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
22492249
}];
22502250
}
22512251

2252+
//===----------------------------------------------------------------------===//
2253+
// Transpose Conv2D
2254+
//===----------------------------------------------------------------------===//
2255+
2256+
// Transform op that rewrites FHWC-filter 2-D convolutions (plain and
// quantized) into their HWCF-filter counterparts. The C++ side is implemented
// by `applyToOne`, declared in `extraClassDeclaration` below.
def TransposeConv2DOp : Op<Transform_Dialect,
    "structured.transpose_conv2d",
    [FunctionalStyleTransformOpTrait,
     MemoryEffectsOpInterface,
     TransformOpInterface,
     TransformEachOpTrait,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Convert linalg.conv_2d_nhwc_fhwc (or its quantized counterpart
    linalg.conv_2d_nhwc_fhwc_q) into linalg.conv_2d_nhwc_hwcf (respectively
    linalg.conv_2d_nhwc_hwcf_q) by introducing a linalg.transpose on the
    filter tensor/memref.

    Whilst the fhwc filter channel ordering can be desirable for certain targets
    and is a more direct mapping to higher level dialects such as TOSA (which only
    supports this ordering) hwcf is better suited for transformations such as
    img2col which can make use of optimized BLAS routines such as GEMM.

    Returns one handle:
    - The final operation of the sequence that replaces the original
      convolution.

    #### Return modes:

    Returns a silenceable failure if the target is not one of the supported
    convolutions or if the pattern application failed.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target);
  let results = (outs TransformHandleTypeInterface:$transformed);

  let assemblyFormat =
    "$target attr-dict `:` functional-type($target, results)";

  let builders = [
    OpBuilder<(ins "Value":$target)>
  ];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure applyToOne(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::linalg::LinalgOp target,
        ::mlir::transform::ApplyToEachResultList &results,
        ::mlir::transform::TransformState &state);
  }];
}
2300+
22522301
//===----------------------------------------------------------------------===//
22532302
// InsertSliceToCopyOp
22542303
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

+7
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,13 @@ rewriteInIm2Col(RewriterBase &rewriter,
12251225
FailureOr<std::pair<Operation *, Operation *>>
12261226
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
12271227

1228+
/// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by
1229+
/// materializing a linalg.transpose of the filter operand.
/// On success, returns the newly created convolution that replaces `op`;
/// otherwise returns failure. One overload handles the plain convolution and
/// the other the quantized (`_q`) convolution.
1230+
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1231+
linalg::Conv2DNhwcFhwcOp op);
1232+
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1233+
linalg::Conv2DNhwcFhwcQOp op);
1234+
12281235
//===----------------------------------------------------------------------===//
12291236
// Rewrite patterns wrapping transformations.
12301237
// TODO: every single such pattern should be a close to noop wrapper around a

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -3169,6 +3169,33 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
31693169
return DiagnosedSilenceableFailure::success();
31703170
}
31713171

3172+
//===----------------------------------------------------------------------===//
3173+
// TransposeConv2DOp
3174+
//===----------------------------------------------------------------------===//
3175+
3176+
/// Applies the FHWC -> HWCF filter transposition to a single payload op.
/// On success the handle result is the new convolution; any unsupported op or
/// failed rewrite is reported as the default silenceable failure.
DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // New ops are created immediately before the payload convolution.
  rewriter.setInsertionPoint(target);

  Operation *payload = target.getOperation();

  // Dispatch on the concrete convolution kind; anything other than the two
  // FHWC variants is reported as a match failure.
  auto rewritePayload = [&]() -> FailureOr<Operation *> {
    if (auto conv = dyn_cast<linalg::Conv2DNhwcFhwcOp>(payload))
      return transposeConv2D(rewriter, conv);
    if (auto quantizedConv = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(payload))
      return transposeConv2D(rewriter, quantizedConv);
    return rewriter.notifyMatchFailure(payload, "not supported");
  };

  FailureOr<Operation *> rewritten = rewritePayload();
  if (failed(rewritten))
    return emitDefaultSilenceableFailure(target);

  // Expose the new convolution (the one consuming the transposed filter).
  results.push_back(*rewritten);
  return DiagnosedSilenceableFailure::success();
}
3198+
31723199
//===----------------------------------------------------------------------===//
31733200
// InsertSliceToCopyOp
31743201
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
3232
Tiling.cpp
3333
TilingInterfaceImpl.cpp
3434
Transforms.cpp
35+
TransposeConv2D.cpp
3536
Vectorization.cpp
3637

3738
ADDITIONAL_HEADER_DIRS
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Func/IR/FuncOps.h"
10+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
11+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
12+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
13+
#include "mlir/IR/BuiltinTypes.h"
14+
#include "mlir/IR/PatternMatch.h"
15+
#include "mlir/IR/ValueRange.h"
16+
#include "mlir/Support/LogicalResult.h"
17+
#include "mlir/Transforms/DialectConversion.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
#include "llvm/Support/ErrorHandling.h"
21+
#include "llvm/Support/RWMutex.h"
22+
#include <memory>
23+
#include <numeric>
24+
25+
namespace mlir {
26+
namespace linalg {
27+
namespace {
28+
// clang-format off
29+
/// Convolution converter that applies the following rewrite:
30+
///
31+
/// Before:
32+
///
33+
/// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
34+
/// strides = dense<2> : tensor<2xi64>}
35+
/// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
36+
/// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
37+
///
38+
/// After:
39+
///
40+
/// %cst = arith.constant 0.000000e+00 : f32
41+
/// %0 = tensor.empty() : tensor<2x2x6x8xf32>
42+
/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
43+
/// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
44+
/// permutation = [1, 2, 3, 0]
45+
/// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
46+
/// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
47+
/// -> tensor<1x2x2x8xf32>
48+
///
49+
/// with an analogous example for the quantized case.
50+
// clang-format on
51+
template <typename FHWCConvOp, typename HWCFConvOp>
52+
FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
53+
FHWCConvOp op) {
54+
// Construct a permutation of the filter tensor dimensions. For a 2D
55+
// convolution this will be known statically as [1, 2, 3, 0].
56+
SmallVector<int64_t> filterPerm({1, 2, 3, 0});
57+
58+
// Create the type for the transposed filter tensor.
59+
auto filter = op->getOperand(1);
60+
auto filterTy = cast<ShapedType>(filter.getType());
61+
SmallVector<int64_t> newFilterShape(filterPerm.size());
62+
std::generate(std::begin(newFilterShape), std::end(newFilterShape),
63+
[dim = 0, &filterTy, &filterPerm]() mutable {
64+
return filterTy.getShape()[filterPerm[dim++]];
65+
});
66+
67+
// Because linalg.transpose expects an "out" parameter we need to pass it a
68+
// tensor of zeros of the result type so here we construct that tensor.
69+
auto inputType = op->getOperand(0).getType();
70+
auto elementTy = cast<ShapedType>(inputType).getElementType();
71+
auto loc = op->getLoc();
72+
73+
const auto isTensorOp = isa<TensorType>(inputType);
74+
Value input;
75+
if (isTensorOp) {
76+
77+
input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
78+
.getResult();
79+
} else {
80+
input = rewriter
81+
.create<memref::AllocOp>(
82+
loc, MemRefType::get(newFilterShape, elementTy))
83+
.getResult();
84+
}
85+
86+
// We can then construct the transposition on our filter.
87+
auto transpose =
88+
rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
89+
90+
Value newFilter;
91+
if (isTensorOp) {
92+
newFilter = transpose.getResult()[0];
93+
} else {
94+
newFilter = input;
95+
}
96+
97+
SmallVector<Value> newInputs{op.getInputs()};
98+
// The filter is always the second input argument, the other inputs can be
99+
// left as they are.
100+
newInputs[1] = newFilter;
101+
// It is possible the convolution doesn't define any results and its
102+
// out argument is just used instead.
103+
SmallVector<Type> resultTy;
104+
if (op.getNumResults()) {
105+
resultTy.push_back(op->getResult(0).getType());
106+
}
107+
auto newConv =
108+
rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
109+
op.getStrides(), op.getDilations());
110+
rewriter.replaceOp(op, newConv);
111+
return newConv.getOperation();
112+
}
113+
114+
template <typename FHWCConvOp, typename HWCFConvOp>
115+
class ConvConverter : public OpRewritePattern<FHWCConvOp> {
116+
public:
117+
using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
118+
LogicalResult matchAndRewrite(FHWCConvOp op,
119+
PatternRewriter &rewriter) const final {
120+
if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
121+
return failure();
122+
}
123+
return success();
124+
}
125+
};
126+
} // namespace
127+
128+
/// Non-quantized entry point: rewrites linalg.conv_2d_nhwc_fhwc into
/// linalg.conv_2d_nhwc_hwcf operating on a transposed filter.
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op) {
  using SourceConvOp = linalg::Conv2DNhwcFhwcOp;
  using TargetConvOp = linalg::Conv2DNhwcHwcfOp;
  return transposeConv2DHelper<SourceConvOp, TargetConvOp>(rewriter, op);
}
134+
135+
/// Quantized entry point: rewrites linalg.conv_2d_nhwc_fhwc_q into
/// linalg.conv_2d_nhwc_hwcf_q operating on a transposed filter.
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcQOp op) {
  using SourceConvOp = linalg::Conv2DNhwcFhwcQOp;
  using TargetConvOp = linalg::Conv2DNhwcHwcfQOp;
  return transposeConv2DHelper<SourceConvOp, TargetConvOp>(rewriter, op);
}
141+
142+
/// Registers the FHWC -> HWCF conversion patterns (plain and quantized) in
/// `patterns`.
// NOTE(review): the function name is spelled "Tranpose" (missing 's'); it is
// kept as-is here because renaming would change the public symbol.
void populateTranposeConv2DPatterns(RewritePatternSet &patterns) {
  using PlainConverter =
      ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>;
  using QuantizedConverter =
      ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>;

  patterns.insert<PlainConverter, QuantizedConverter>(patterns.getContext());
}
149+
} // namespace linalg
150+
} // namespace mlir

0 commit comments

Comments
 (0)