@@ -1502,6 +1502,9 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1502
1502
auto resultTy = cast<ShapedType>(op.getType ());
1503
1503
auto resultETy = resultTy.getElementType ();
1504
1504
1505
+ bool floatingPointMode = resultETy.isF16 () || resultETy.isF32 ();
1506
+ auto floatTy = resultETy.isF16 () ? b.getF16Type () : b.getF32Type ();
1507
+
1505
1508
auto imageH = inputTy.getShape ()[1 ];
1506
1509
auto imageW = inputTy.getShape ()[2 ];
1507
1510
@@ -1535,16 +1538,13 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1535
1538
1536
1539
Value zeroI32 =
1537
1540
b.create <arith::ConstantOp>(b.getZeroAttr (b.getI32Type ()));
1538
- Value zeroFp32 =
1539
- b.create <arith::ConstantOp>(b.getZeroAttr (b.getF32Type ()));
1541
+ Value zeroFp = b.create <arith::ConstantOp>(b.getZeroAttr (floatTy));
1540
1542
Value hMax = b.create <arith::ConstantOp>(b.getI32IntegerAttr (imageH - 1 ));
1541
1543
Value wMax = b.create <arith::ConstantOp>(b.getI32IntegerAttr (imageW - 1 ));
1542
1544
1543
1545
Value inY = b.create <arith::IndexCastOp>(b.getI32Type (), y);
1544
1546
Value inX = b.create <arith::IndexCastOp>(b.getI32Type (), x);
1545
1547
1546
- bool floatingPointMode = resultETy.isF32 ();
1547
-
1548
1548
ArrayRef<int64_t > offset = op.getOffset ();
1549
1549
ArrayRef<int64_t > border = op.getBorder ();
1550
1550
ArrayRef<int64_t > scale = op.getScale ();
@@ -1567,16 +1567,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1567
1567
int size, ImplicitLocOpBuilder &b) {
1568
1568
if (size == 1 ) {
1569
1569
index = zeroI32;
1570
- delta = zeroFp32 ;
1570
+ delta = zeroFp ;
1571
1571
return ;
1572
1572
}
1573
1573
// x = x * scale_d + offset;
1574
1574
// ix = floor(x / scale_n)
1575
1575
// dx = x / scale_n - ix
1576
- Value val = b.create <arith::UIToFPOp>(b. getF32Type () , in);
1577
- scaleN = b.create <arith::UIToFPOp>(b. getF32Type () , scaleN);
1578
- scaleD = b.create <arith::UIToFPOp>(b. getF32Type () , scaleD);
1579
- offset = b.create <arith::SIToFPOp>(b. getF32Type () , offset);
1576
+ Value val = b.create <arith::UIToFPOp>(floatTy , in);
1577
+ scaleN = b.create <arith::UIToFPOp>(floatTy , scaleN);
1578
+ scaleD = b.create <arith::UIToFPOp>(floatTy , scaleD);
1579
+ offset = b.create <arith::SIToFPOp>(floatTy , offset);
1580
1580
val = b.create <arith::MulFOp>(val, scaleD);
1581
1581
val = b.create <arith::AddFOp>(val, offset);
1582
1582
val = b.create <arith::DivFOp>(val, scaleN);
@@ -1625,7 +1625,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1625
1625
1626
1626
Value pred;
1627
1627
if (floatingPointMode) {
1628
- auto h = b.create <arith::ConstantOp>(b.getF32FloatAttr ( 0 .5f ));
1628
+ auto h = b.create <arith::ConstantOp>(b.getFloatAttr (floatTy, 0 .5f ));
1629
1629
pred = b.create <arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1630
1630
} else {
1631
1631
Value dvalDouble = b.create <arith::ShLIOp>(dval, one);
@@ -1681,7 +1681,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1681
1681
input, ValueRange{batch, y1 , x1, channel});
1682
1682
1683
1683
if (floatingPointMode) {
1684
- auto oneVal = b.create <arith::ConstantOp>(b.getF32FloatAttr (1 .0f ));
1684
+ auto oneVal =
1685
+ b.create <arith::ConstantOp>(b.getFloatAttr (floatTy, 1 .0f ));
1685
1686
auto interpolate = [&](Value val0, Value val1, Value delta,
1686
1687
int inputSize,
1687
1688
ImplicitLocOpBuilder &b) -> Value {
0 commit comments