Skip to content

Commit 3a772c3

Browse files
authored
[mlir][tosa] Add fp16 support to tosa.resize (#73019)
1 parent a9673bd commit 3a772c3

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,6 +1502,9 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15021502
auto resultTy = cast<ShapedType>(op.getType());
15031503
auto resultETy = resultTy.getElementType();
15041504

1505+
bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1506+
auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1507+
15051508
auto imageH = inputTy.getShape()[1];
15061509
auto imageW = inputTy.getShape()[2];
15071510

@@ -1535,16 +1538,13 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15351538

15361539
Value zeroI32 =
15371540
b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1538-
Value zeroFp32 =
1539-
b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
1541+
Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
15401542
Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
15411543
Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
15421544

15431545
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
15441546
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
15451547

1546-
bool floatingPointMode = resultETy.isF32();
1547-
15481548
ArrayRef<int64_t> offset = op.getOffset();
15491549
ArrayRef<int64_t> border = op.getBorder();
15501550
ArrayRef<int64_t> scale = op.getScale();
@@ -1567,16 +1567,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15671567
int size, ImplicitLocOpBuilder &b) {
15681568
if (size == 1) {
15691569
index = zeroI32;
1570-
delta = zeroFp32;
1570+
delta = zeroFp;
15711571
return;
15721572
}
15731573
// x = x * scale_d + offset;
15741574
// ix = floor(x / scale_n)
15751575
// dx = x / scale_n - ix
1576-
Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in);
1577-
scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN);
1578-
scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD);
1579-
offset = b.create<arith::SIToFPOp>(b.getF32Type(), offset);
1576+
Value val = b.create<arith::UIToFPOp>(floatTy, in);
1577+
scaleN = b.create<arith::UIToFPOp>(floatTy, scaleN);
1578+
scaleD = b.create<arith::UIToFPOp>(floatTy, scaleD);
1579+
offset = b.create<arith::SIToFPOp>(floatTy, offset);
15801580
val = b.create<arith::MulFOp>(val, scaleD);
15811581
val = b.create<arith::AddFOp>(val, offset);
15821582
val = b.create<arith::DivFOp>(val, scaleN);
@@ -1625,7 +1625,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16251625

16261626
Value pred;
16271627
if (floatingPointMode) {
1628-
auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f));
1628+
auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
16291629
pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
16301630
} else {
16311631
Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
@@ -1681,7 +1681,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16811681
input, ValueRange{batch, y1, x1, channel});
16821682

16831683
if (floatingPointMode) {
1684-
auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
1684+
auto oneVal =
1685+
b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
16851686
auto interpolate = [&](Value val0, Value val1, Value delta,
16861687
int inputSize,
16871688
ImplicitLocOpBuilder &b) -> Value {

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,41 @@
11
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -o -| FileCheck %s
22

3-
// CHECK-LABEL: @unary_resize_nearest_fp
4-
func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
3+
// CHECK-LABEL: @unary_resize_nearest_fp32
4+
func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
55
%resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
66
// CHECK: return %arg0
77
return %resize : tensor<3x1x1x7xf32>
88
}
99

1010
// -----
1111

12-
// CHECK-LABEL: @unary_resize_bilinear_fp
13-
func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
12+
// CHECK-LABEL: @unary_resize_nearest_fp16
13+
func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
14+
%resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
15+
// CHECK: return %arg0
16+
return %resize : tensor<3x1x1x7xf16>
17+
}
18+
19+
// -----
20+
21+
// CHECK-LABEL: @unary_resize_bilinear_fp32
22+
func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
1423
%resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
1524
// CHECK: return %arg0
1625
return %resize : tensor<3x1x1x7xf32>
1726
}
1827

1928
// -----
2029

30+
// CHECK-LABEL: @unary_resize_bilinear_fp16
31+
func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
32+
%resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
33+
// CHECK: return %arg0
34+
return %resize : tensor<3x1x1x7xf16>
35+
}
36+
37+
// -----
38+
2139
// CHECK-LABEL: @unary_resize_nearest_i8
2240
func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8> {
2341
%resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 1, 3, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8>
@@ -285,8 +303,8 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
285303

286304
// -----
287305

288-
// CHECK-LABEL: @resize_nearest_fp
289-
func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
306+
// CHECK-LABEL: @resize_nearest_fp32
307+
func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () {
290308
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32>
291309
// CHECK: %[[GENERIC:.+]] = linalg.generic
292310
// CHECK: %[[IDX0:.+]] = linalg.index 0

0 commit comments

Comments
 (0)