Commit acc6f3e
TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops. (#70482)
Switching to FHWC happened in #68304 and is fine in itself, but it caused a downstream performance regression (iree-org/iree#15296), so this PR makes the kernel layout optional.
Parent: 11b3b38

File tree: 6 files changed, +73 -8 lines

mlir/include/mlir/Conversion/Passes.td

+6
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
     Linalg named operations.
   }];
 
+  let options = [
+    Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
+           "bool", /*default=*/"false",
+           "Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
+  ];
+
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
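On the command line the new option is spelled prefer-conv2d-kernel-layout-hwcf, as the updated RUN line in the test at the end of this commit exercises. For C++ callers, here is a minimal sketch (illustrative, not part of the commit) of requesting the HWCF layout programmatically through the API added in the header below; the helper name is hypothetical:

```cpp
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Nest the named-ops lowering under func.func with the HWCF preference set.
static void addTosaToLinalgNamedHwcf(OpPassManager &pm) {
  TosaToLinalgNamedOptions options;
  options.preferConv2DKernelLayoutHWCF = true; // default is false (FHWC)
  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed(options));
}
```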

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

+6 -2

@@ -26,14 +26,17 @@ namespace mlir {
 namespace tosa {
 
 std::unique_ptr<Pass> createTosaToLinalg();
-std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgNamed(
+    const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());
 
 /// Populates passes to convert from TOSA to Linalg on buffers. At the end of
 /// the pass, the function will only contain linalg ops or standard ops if the
 /// pipeline succeeds. The option to disable decompositions is available for
 /// benchmarking performance improvements from the canonicalizations.
 void addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
+        TosaToLinalgNamedOptions(),
     // Note: Default to 'none' level unless otherwise specified.
     tosa::TosaValidationOptions const &validationOptions = {
         tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});

@@ -46,7 +49,8 @@ void registerTosaToLinalgPipelines();
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
-void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgNamedConversionPatterns(
+    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
 
 } // namespace tosa
 } // namespace mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

+40 -2

@@ -26,6 +26,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
+    if (4 == inputTy.getRank()) {
+      // For 2D convolutions, we need to check if the target convolution op
+      // wants a HWCF kernel layout.
+      bool wantHwcf =
+          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+      if (wantHwcf) {
+        // Transpose the kernel to match dimension ordering of the linalg
+        // convolution operation.
+        // TODO(suderman): See if this can be efficiently folded - check whether
+        // the input is used anywhere else, if not fold the constant.
+        SmallVector<int64_t> weightPerm;
+        for (int i = 1; i < resultTy.getRank(); i++)
+          weightPerm.push_back(i);
+        weightPerm.push_back(0);
+
+        SmallVector<int64_t> newWeightShape;
+        for (auto dim : weightPerm)
+          newWeightShape.push_back(weightShape[dim]);
+        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+        Value weightPermValue =
+            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+        Type newWeightTy =
+            RankedTensorType::get(newWeightShape, weightTy.getElementType());
+        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                    weightPermValue);
+      }
+    }
+
     // For Conv3D transpose the kernel to match dimension ordering of the linalg
     // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
     // map directly and then transpose later if desired.
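As a sanity check on the permutation built above: TOSA Conv2D weights are FHWC, i.e. [outChannels, kH, kW, inChannels], and rotating dimension 0 to the back yields HWCF, i.e. perm = [1, 2, 3, 0] for rank-4 weights. A standalone sketch (illustrative only; the shape is borrowed from the @conv2d_f32 test below):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // FHWC weight shape from the @conv2d_f32 test: [28, 3, 3, 27].
  std::vector<int64_t> weightShape = {28, 3, 3, 27};

  // Same permutation the pattern builds: [1, 2, 3, 0] for rank 4.
  std::vector<int64_t> weightPerm;
  for (size_t i = 1; i < weightShape.size(); ++i)
    weightPerm.push_back(static_cast<int64_t>(i));
  weightPerm.push_back(0);

  // Applying it yields the HWCF shape fed to linalg.conv_2d_nhwc_hwcf.
  std::vector<int64_t> newWeightShape;
  for (int64_t dim : weightPerm)
    newWeightShape.push_back(weightShape[dim]);

  for (int64_t dim : newWeightShape)
    std::cout << dim << ' '; // prints: 3 3 27 28
  std::cout << '\n';
}
```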
@@ -977,10 +1007,18 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns) {
+    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
+  if (options.preferConv2DKernelLayoutHWCF) {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
+                                linalg::Conv2DNhwcHwcfQOp>>(
+        patterns->getContext());
+  } else {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+                                linalg::Conv2DNhwcFhwcQOp>>(
+        patterns->getContext());
+  }
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
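Note that this registration instantiates exactly one Conv2D ConvConverter per pass run, so the std::is_same_v check inside the pattern is resolved at template-instantiation time and the HWCF branch is dead code in the FHWC configuration. A stripped-down sketch of that compile-time dispatch (the op structs below are stand-ins, not the generated linalg classes):

```cpp
#include <type_traits>

// Stand-ins for the generated linalg op classes; illustrative only.
struct Conv2DNhwcHwcfOp {};
struct Conv2DNhwcFhwcOp {};

// Mirrors the wantHwcf test in ConvConverter: only the *_nhwc_hwcf
// instantiation requests the kernel transpose.
template <typename LinalgConvOp>
constexpr bool wantsHwcfKernel =
    std::is_same_v<LinalgConvOp, Conv2DNhwcHwcfOp>;

static_assert(wantsHwcfKernel<Conv2DNhwcHwcfOp>);
static_assert(!wantsHwcfKernel<Conv2DNhwcFhwcOp>);
```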

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

+9 -3

@@ -37,6 +37,9 @@ namespace {
 struct TosaToLinalgNamed
     : public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
 public:
+  TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
+      : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,

@@ -61,13 +64,16 @@ struct TosaToLinalgNamed
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
     FunctionOpInterface func = getOperation();
-    mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
+    TosaToLinalgNamedOptions options;
+    options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
+    tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
-  return std::make_unique<TosaToLinalgNamed>();
+std::unique_ptr<Pass>
+mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
+  return std::make_unique<TosaToLinalgNamed>(options);
 }

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

+5 -1

@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 
 void mlir::tosa::addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
     tosa::TosaValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!options.disableTosaDecompositions)

@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(
 
   pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
+  pm.addNestedPass<func::FuncOp>(
+      tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(

@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
       "named operations.",
       [](OpPassManager &pm) {
         TosaToLinalgOptions tosaToLinalgOptions;
+        TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+                                    tosaToLinalgNamedOptions,
                                     /* validationOptions = */
                                     {tosa::TosaProfileEnum::BaseInference,
                                      /* StrictOperationSpecAlignment = */ true,
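Because the new parameter is defaulted in the header, existing callers of addTosaToLinalgPasses keep compiling unchanged. A minimal sketch (illustrative, not part of the commit) of a caller that opts into HWCF while leaving the validation options at their defaults:

```cpp
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

static void buildPipeline(OpPassManager &pm) {
  TosaToLinalgOptions tosaToLinalgOptions;
  TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
  tosaToLinalgNamedOptions.preferConv2DKernelLayoutHWCF = true;
  // validationOptions falls back to its declared default ('none' level).
  tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
                              tosaToLinalgNamedOptions);
}
```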

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

+7

@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {

@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK: arith.extsi
   // CHECK: arith.addi

@@ -383,11 +387,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK: arith.addf
   // CHECK: linalg.yield
