Skip to content

Commit 967ab7e

Browse files
authored
[mlir][TOSA] Fix linalg lowering of depthwise conv2d (#130293)
Current lowering for tosa.depthwise_conv2d assumes if both zero points are zero then it's a floating-point operation by hardcoding the use of a arith.addf in the lowered code. Fix code to check for the element type to decide what add operation to use.
1 parent 8885b5c commit 967ab7e

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,13 +477,13 @@ class DepthwiseConvConverter
477477
return rewriter.notifyMatchFailure(
478478
op, "weight zero point must be zero for non-int8 integer types");
479479

480-
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
480+
bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
481481
auto weightShape = weightTy.getShape();
482482
auto resultShape = resultTy.getShape();
483483

484484
// Apply padding as necessary.
485485
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
486-
if (hasZp) {
486+
if (!hasNullZps) {
487487
int64_t intMin =
488488
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
489489
.getSExtValue();
@@ -536,7 +536,7 @@ class DepthwiseConvConverter
536536
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
537537
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
538538

539-
if (!hasZp) {
539+
if (hasNullZps) {
540540
Value conv = rewriter
541541
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
542542
loc, linalgConvTy, ValueRange{input, weight},
@@ -556,8 +556,13 @@ class DepthwiseConvConverter
556556
getNParallelLoopsAttrs(resultRank),
557557
[&](OpBuilder &nestedBuilder, Location nestedLoc,
558558
ValueRange args) {
559-
Value added = nestedBuilder.create<arith::AddFOp>(
560-
loc, args[0], args[1]);
559+
Value added;
560+
if (llvm::isa<FloatType>(inputETy))
561+
added = nestedBuilder.create<arith::AddFOp>(loc, args[0],
562+
args[1]);
563+
else
564+
added = nestedBuilder.create<arith::AddIOp>(loc, args[0],
565+
args[1]);
561566
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
562567
})
563568
.getResult(0);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,30 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
824824

825825
// -----
826826

827+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
828+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
829+
830+
// CHECK-LABEL: @depthwise_int_conv_zero_zp
831+
func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor<3x1x3x11xi8>, %arg2 : tensor<33xi32>) -> () {
832+
// CHECK: [[INIT:%.+]] = tensor.empty()
833+
// CHECK: [[CST0:%.+]] = arith.constant 0
834+
// CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
835+
// CHECK: [[OUT:%.+]] = tensor.empty()
836+
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xi8>, tensor<3x1x3x11xi8>) outs([[FILL]] : tensor<1x5x5x3x11xi32>)
837+
// CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
838+
// CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xi32>, tensor<1x5x5x33xi32>) outs([[OUT]] : tensor<1x5x5x33xi32>) {
839+
// CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: i32, %[[ARG4:[0-9a-zA-Z_]+]]: i32, %[[ARG5:[0-9a-zA-Z_]+]]: i32):
840+
// CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
841+
// CHECK: linalg.yield [[ADD]] : i32
842+
// CHECK: } -> tensor<1x5x5x33xi32>
843+
%input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
844+
%weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
845+
%2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xi8>, tensor<3x1x3x11xi8>, tensor<33xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x5x33xi32>
846+
return
847+
}
848+
849+
// -----
850+
827851
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
828852
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
829853

0 commit comments

Comments
 (0)