Commit e29a253
[mlir][tosa][linalg] Apply direct tosa -> linalg Conv2D lowering (#68304)
TOSA defines the filter channel ordering for 2D convolution operation `tosa.conv2d` as `[OC, KH, KW, IC]`. The LinAlg dialect supports `[F, H, W, C]` and `[H, W, C, F]` orderings via the `linalg.conv_2d_nhwc_fhwc` and `linalg.conv_2d_nhwc_hwcf` operations respectively. Where `F == OC`, `KH == H`, `KW == W` and `C == IC`. Currently `tosa.conv2d` is lowered to `linalg.conv2d_nhwc_hwcf` meaning we need to insert a transposition operation to permute the filter channels before they can be passed as weights to the linalg op, that is `[F, H, W, C]` -> `[H, W, C, F]`. An analogous transformation needs to be applied to the quantized operation that lowers to `linalg.conv_2d_nhwc_hwcf_q`. This commit updates the TOSA->LinAlg lowering so that `tosa.conv2d` is lowered to `linalg.conv2d_nhwc_fhwc` removing the need for the introduction of a transposition operation and making the mapping 1-1. It also adds a `linalg.conv_2d_nhwc_fhwc_q` quantized operation to the LinAlg dialect so the same direct 1-1 mapping can be applied to the quantized variant. This commit does not add any new lit tests but repurposes the current TosaToLinalgNamed tests by removing the checks for transpositions and updating the targeted LinAlg operations from `linalg.conv2d_nhwc_hwcf` to linalg.conv2d_nhwc_fhwc`. Signed-off-by: Jack Frankland <[email protected]>
1 parent 7850225 commit e29a253
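To illustrate the effect on the lowered IR, here is a minimal before/after sketch distilled from the updated lit tests below (SSA value names are illustrative):

// Before: tosa.conv2d went through an explicit kernel transpose,
// [F, H, W, C] -> [H, W, C, F], to feed linalg.conv_2d_nhwc_hwcf:
//   %perm = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
//   %w = tosa.transpose %weights, %perm
//   %res = linalg.conv_2d_nhwc_hwcf ... ins(%input, %w : ...) outs(%fill : ...)
//
// After: the TOSA kernel layout [F, H, W, C] is consumed directly:
%res = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
         ins(%input, %weights : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>)
         outs(%fill : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>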

File tree

4 files changed: +196 −34 lines

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

+137
@@ -2575,6 +2575,143 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: KZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwc_fhwc_q
+  cpp_class_name: Conv2DNhwcFhwcQOp
+  doc: |-
+    Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
+      s3, s7, s9)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1, s5, s10)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
   cpp_class_name: Conv2DNchwFchwOp
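The new op is used like the existing hwcf_q variant, with the input/kernel zero points passed as trailing i32 scalars. In the shape maps above, s0..s10 correspond to (N, OH, SH, KH, DH, OW, SW, KW, DW, C, F), matching the opdsl definition added to core_named_ops.py below. A minimal usage sketch, with shapes borrowed from the updated @conv2d_i8 lit test (the %izp/%kzp names are illustrative):

%conv = linalg.conv_2d_nhwc_fhwc_q
          {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
          ins(%input, %weights, %izp, %kzp
              : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32)
          outs(%acc : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>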

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

+23 −20

@@ -248,25 +248,28 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    // Transpose the kernel to match dimension ordering of the linalg
-    // convolution operation.
-    // TODO(suderman): See if this can be efficiently folded - check whether
-    // the input is used anywhere else, if not fold the constant.
-    SmallVector<int64_t> weightPerm;
-    for (int i = 1; i < resultTy.getRank(); i++)
-      weightPerm.push_back(i);
-    weightPerm.push_back(0);
-
-    SmallVector<int64_t> newWeightShape;
-    for (auto dim : weightPerm)
-      newWeightShape.push_back(weightShape[dim]);
-    auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
-    Value weightPermValue =
-        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
-    Type newWeightTy =
-        RankedTensorType::get(newWeightShape, weightTy.getElementType());
-    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                weightPermValue);
+    // For Conv3D transpose the kernel to match dimension ordering of the linalg
+    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
+    // map directly and then transpose later if desired.
+    if (5 == inputTy.getRank()) {
+      // TODO(suderman): See if this can be efficiently folded - check whether
+      // the input is used anywhere else, if not fold the constant.
+      SmallVector<int64_t> weightPerm;
+      for (int i = 1; i < resultTy.getRank(); i++)
+        weightPerm.push_back(i);
+      weightPerm.push_back(0);
+
+      SmallVector<int64_t> newWeightShape;
+      for (auto dim : weightPerm)
+        newWeightShape.push_back(weightShape[dim]);
+      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+      Value weightPermValue =
+          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+      Type newWeightTy =
+          RankedTensorType::get(newWeightShape, weightTy.getElementType());
+      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                  weightPermValue);
+    }
 
     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
@@ -977,7 +980,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
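Note that the transpose path is now taken only for rank-5 (Conv3D) inputs. A hypothetical tosa.conv3d lowering would therefore still emit a kernel permutation before the linalg op, roughly as follows (a sketch only; shapes and value names are illustrative):

// Conv3D weights [F, D, H, W, C] are still permuted to [D, H, W, C, F]
// before being passed to linalg.conv_3d_ndhwc_dhwcf:
%perm = arith.constant dense<[1, 2, 3, 4, 0]> : tensor<5xi64>
%w = tosa.transpose %weights, %perm
       : (tensor<28x3x3x3x27xf32>, tensor<5xi64>) -> tensor<3x3x3x27x28xf32>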

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

+30
@@ -693,6 +693,36 @@ def conv_2d_nhwc_hwcf_q(
     ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
 
 
+@linalg_structured_op
+def conv_2d_nhwc_fhwc_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
+
+
 @linalg_structured_op
 def conv_2d_nchw_fchw(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
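Spelled out, the += in the opdsl definition accumulates over the reduction dimensions (D.kh, D.kw, D.c) of the declared domain, i.e. for each output element (informally, with cast standing for TypeFn.cast_signed):

O[n, oh, ow, f] += sum over (kh, kw, c) of
    (cast<U>(I[n, oh*SH + kh*DH, ow*SW + kw*DW, c]) - cast<U>(IZp))
  * (cast<U>(K[f, kh, kw, c]) - cast<U>(KZp))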

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

+6 −14

@@ -363,13 +363,11 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK: arith.extsi
   // CHECK: arith.addi
@@ -385,13 +383,11 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK: arith.addf
   // CHECK: linalg.yield
@@ -408,13 +404,11 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[ADD:.+]] = arith.addf
   // CHECK: linalg.yield %[[ADD]] : f32
@@ -468,13 +462,11 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[ADD:.+]] = arith.addf
   // CHECK: linalg.yield %[[ADD]] : f32
@@ -489,7 +481,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: tensor.yield %[[C0]]
-  // CHECK: linalg.conv_2d_nhwc_hwcf
+  // CHECK: linalg.conv_2d_nhwc_fhwc
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -501,7 +493,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK: %[[C22:.+]] = arith.constant -22
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: tensor.yield %[[C22]]
-  // CHECK: linalg.conv_2d_nhwc_hwcf_q
+  // CHECK: linalg.conv_2d_nhwc_fhwc_q
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
   return
 }
