Skip to content

Commit 3430bc3

Browse files
HsiangkaiTai78641lhutton1
authored
[mlir][tosa] Make TOSA RESIZE's scale, offset, border as Input (#124956)
Move the `scale`, `offset`, and `border` parameters of the RESIZE operator in the MLIR TOSA dialect from attributes to inputs and update lit tests appropriately. Add the verifier of the `tosa::ResizeOp` operation. --------- Co-authored-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
1 parent 7c24041 commit 3430bc3

File tree

13 files changed

+558
-72
lines changed

13 files changed

+558
-72
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,9 +1796,9 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
17961796

17971797
let arguments = (ins
17981798
Tosa_Tensor4D:$input,
1799-
Tosa_IntArrayAttr4:$scale,
1800-
Tosa_IntArrayAttr2:$offset,
1801-
Tosa_IntArrayAttr2:$border,
1799+
Rank4TosaShape:$scale,
1800+
Rank2TosaShape:$offset,
1801+
Rank2TosaShape:$border,
18021802
Tosa_ResizeTypeAttr:$mode
18031803
);
18041804

@@ -1807,6 +1807,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18071807
);
18081808

18091809
let hasFolder = 1;
1810+
let hasVerifier = 1;
18101811
}
18111812

18121813
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
240240
bool getConstShapeValue(Operation *op,
241241
llvm::SmallVector<int64_t> &result_shape);
242242

243+
// returns a small vector of int64_t values that attr contains
244+
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
245+
const int rank);
243246
} // namespace tosa
244247
} // namespace mlir
245248

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,7 +1387,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
13871387
return success();
13881388
}
13891389

1390-
ArrayRef<int64_t> scale = op.getScale();
1390+
SmallVector<int64_t> scale;
1391+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
1392+
return failure();
1393+
}
13911394

13921395
// Collapse the unit width and height away.
13931396
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1488,8 +1491,9 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
14881491
resizeShape.push_back(channels);
14891492

14901493
auto resizeTy = resultTy.clone(resizeShape);
1491-
auto resize =
1492-
builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1494+
auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
1495+
op.getOffset(), op.getBorder(),
1496+
op.getMode());
14931497

14941498
// Collapse an unit result dims.
14951499
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1604,9 +1608,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16041608
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
16051609
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
16061610

1607-
ArrayRef<int64_t> offset = op.getOffset();
1608-
ArrayRef<int64_t> border = op.getBorder();
1609-
ArrayRef<int64_t> scale = op.getScale();
1611+
SmallVector<int64_t> scale, offset, border;
1612+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
1613+
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
1614+
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
1615+
return rewriter.notifyMatchFailure(
1616+
op, "tosa.resize scale/offset/border should have compile time "
1617+
"constant values.");
1618+
}
16101619

16111620
Value yScaleN, yScaleD, xScaleN, xScaleD;
16121621
yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,9 +1034,22 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
10341034
// Fold away cases where a tosa.resize operation returns a copy
10351035
// of the input image.
10361036
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1037-
ArrayRef<int64_t> offset = getOffset();
1038-
ArrayRef<int64_t> border = getBorder();
1039-
ArrayRef<int64_t> scale = getScale();
1037+
auto scaleAttr =
1038+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1039+
auto offsetAttr =
1040+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1041+
auto borderAttr =
1042+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1043+
if (!scaleAttr || !offsetAttr || !borderAttr) {
1044+
return {};
1045+
}
1046+
1047+
auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1048+
auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1049+
auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1050+
if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1051+
return {};
1052+
}
10401053

10411054
// Check unit scaling.
10421055
if (scale[0] != scale[1] || scale[2] != scale[3]) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,9 +1598,14 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
15981598
(inputWidth == ShapedType::kDynamic))
15991599
return failure();
16001600

1601-
llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1602-
llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1603-
llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1601+
SmallVector<int64_t> scaleInt, offsetInt, borderInt;
1602+
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
1603+
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
1604+
offsetInt) ||
1605+
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
1606+
borderInt)) {
1607+
return failure();
1608+
}
16041609

16051610
// Compute the output shape based on attributes: scale, offset, and border.
16061611
outputShape[1] =
@@ -1617,6 +1622,98 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
16171622
return success();
16181623
}
16191624

1625+
LogicalResult tosa::ResizeOp::verify() {
1626+
const Value input = getInput();
1627+
const Value output = getOutput();
1628+
const RankedTensorType inputType =
1629+
llvm::dyn_cast<RankedTensorType>(input.getType());
1630+
const RankedTensorType outputType =
1631+
llvm::dyn_cast<RankedTensorType>(output.getType());
1632+
1633+
if (!inputType)
1634+
return emitOpError("expect a ranked input tensor");
1635+
if (!outputType)
1636+
return emitOpError("expect a ranked output tensor");
1637+
1638+
const int64_t oh = outputType.getDimSize(1);
1639+
const int64_t ow = outputType.getDimSize(2);
1640+
const int64_t ih = inputType.getDimSize(1);
1641+
const int64_t iw = inputType.getDimSize(2);
1642+
1643+
SmallVector<int64_t> scaleValues;
1644+
SmallVector<int64_t> offsetValues;
1645+
SmallVector<int64_t> borderValues;
1646+
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
1647+
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
1648+
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
1649+
// Skip following checks if shape is not constant
1650+
return success();
1651+
}
1652+
1653+
if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
1654+
return emitOpError("expect all scale values to be > 0, got ")
1655+
<< scaleValues;
1656+
1657+
const int64_t scaleYN = scaleValues[0];
1658+
const int64_t scaleYD = scaleValues[1];
1659+
const int64_t scaleXN = scaleValues[2];
1660+
const int64_t scaleXD = scaleValues[3];
1661+
1662+
const int64_t offsetY = offsetValues[0];
1663+
const int64_t offsetX = offsetValues[1];
1664+
1665+
const int64_t borderY = borderValues[0];
1666+
const int64_t borderX = borderValues[1];
1667+
1668+
auto idivCheck = [](const int64_t lhs,
1669+
const int64_t rhs) -> std::optional<int64_t> {
1670+
if (lhs % rhs != 0)
1671+
return std::nullopt;
1672+
return lhs / rhs;
1673+
};
1674+
1675+
// Don't check with input height that could be broadcast (ih != 1)
1676+
// since Linalg, a consumer of TOSA, expects broadcasting support
1677+
// in resize to be available. Taking the cautious approach for now,
1678+
// we can consider removing support for broadcasting later.
1679+
if (ih != ShapedType::kDynamic && ih != 1) {
1680+
const std::optional<int64_t> calculatedOutHeightMinusOne =
1681+
idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1682+
if (!calculatedOutHeightMinusOne.has_value())
1683+
return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
1684+
"border_y ")
1685+
<< "to be wholly divisible by scale_y_d, got ((" << ih
1686+
<< " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1687+
<< ") / " << scaleYD;
1688+
const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1689+
if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1690+
return emitOpError("calculated output height did not match expected: ")
1691+
<< "calculated=" << calculatedOutHeight << ", expected=" << oh;
1692+
}
1693+
1694+
// Don't check with input width that could be broadcast (iw != 1)
1695+
// since Linalg, a consumer of TOSA, expects broadcasting support
1696+
// in resize to be available. Taking the cautious approach for now,
1697+
// we can consider removing support for broadcasting later.
1698+
if (iw != ShapedType::kDynamic && iw != 1) {
1699+
const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
1700+
const std::optional<int64_t> calculatedOutWidthMinusOne =
1701+
idivCheck(scaledInWidth, scaleXD);
1702+
if (!calculatedOutWidthMinusOne.has_value())
1703+
return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
1704+
"border_x ")
1705+
<< "to be wholly divisible by scale_x_d, got ((" << iw
1706+
<< " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1707+
<< ") / " << scaleXD;
1708+
const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1709+
if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1710+
return emitOpError("calculated output width did not match expected: ")
1711+
<< "calculated=" << calculatedOutWidth << ", expected=" << ow;
1712+
}
1713+
1714+
return success();
1715+
}
1716+
16201717
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
16211718
MLIRContext *context, ::std::optional<Location> location,
16221719
ScatterOp::Adaptor adaptor,

0 commit comments

Comments
 (0)