Skip to content

Commit 96e401c

Browse files
Jerry-Gejph-13
authored andcommitted
[mlir][tosa] Update value to values for ConstOp and ConstShapeOp (llvm#129943)
Updated the dialect to match TOSA v1.0 specification for ConstOp and ConstShapeOp (https://www.mlplatform.org/tosa/tosa_spec.html#_const). Also updated lit tests --------- Signed-off-by: Jerry Ge <[email protected]>
1 parent bdf2cc5 commit 96e401c

35 files changed

+980
-978
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,7 +2338,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
23382338
// Operator: const
23392339
//===----------------------------------------------------------------------===//
23402340
def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
2341-
AllShapesMatch<["value", "output"]>,
2341+
AllShapesMatch<["values", "output"]>,
23422342
FirstAttrDerivedResultType]> {
23432343
let summary = "Constant op.";
23442344

@@ -2350,12 +2350,12 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
23502350

23512351
```mlir
23522352
// Generic form
2353-
%out = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
2353+
%out = "tosa.const"() {values = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
23542354
```
23552355
}];
23562356

23572357
let arguments = (ins
2358-
ElementsAttr:$value
2358+
ElementsAttr:$values
23592359
);
23602360

23612361
let results = (outs

mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
6767

6868
```mlir
6969
// Generic form
70-
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
70+
%out = "tosa.const_shape"() {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
7171
```
7272
}];
7373

74-
let arguments = (ins IndexElementsAttr : $value);
74+
let arguments = (ins IndexElementsAttr : $values);
7575

7676
let results = (outs Tosa_Shape : $output);
7777

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
237237

238238
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
239239

240-
bool getConstShapeValue(Operation *op,
241-
llvm::SmallVector<int64_t> &result_shape);
240+
bool getConstShapeValues(Operation *op,
241+
llvm::SmallVector<int64_t> &result_shape);
242242

243243
// returns a small vector of int64_t values that attr contains
244244
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,

mlir/lib/Conversion/TosaToArith/TosaToArith.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
2828

2929
LogicalResult matchAndRewrite(tosa::ConstOp op,
3030
PatternRewriter &rewriter) const final {
31-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
31+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
3232
return success();
3333
}
3434
};

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
15781578
}
15791579

15801580
SmallVector<int64_t> scale;
1581-
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
1581+
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
15821582
return failure();
15831583
}
15841584

@@ -1799,9 +1799,9 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
17991799
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
18001800

18011801
SmallVector<int64_t> scale, offset, border;
1802-
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
1803-
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
1804-
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
1802+
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
1803+
!tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
1804+
!tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
18051805
return rewriter.notifyMatchFailure(
18061806
op, "tosa.resize scale/offset/border should have compile time "
18071807
"constant values.");

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
243243
}
244244

245245
llvm::SmallVector<int64_t> newShape;
246-
if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
247-
newShape)) {
246+
if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
247+
newShape)) {
248248
return failure();
249249
}
250250

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -882,9 +882,9 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
882882
return {};
883883
}
884884

885-
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
885+
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
886886

887-
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
887+
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
888888

889889
#define REDUCE_FOLDER(OP) \
890890
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
@@ -947,7 +947,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
947947
return {};
948948

949949
llvm::SmallVector<int64_t> shapeVec;
950-
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
950+
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
951951
return {};
952952

953953
return operand.reshape(

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ static LogicalResult verifyConvOp(T op) {
352352

353353
LogicalResult tosa::ConstOp::verify() {
354354

355-
auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
355+
auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
356356
auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
357357

358358
if (!attrType || !outputType) {
@@ -1179,8 +1179,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
11791179

11801180
SmallVector<int64_t> paddingValues;
11811181
// If the paddings value is not a constant, all dimensions must be dynamic.
1182-
if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
1183-
paddingValues)) {
1182+
if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1183+
paddingValues)) {
11841184
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
11851185
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
11861186
return success();
@@ -1252,8 +1252,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
12521252
SmallVector<int64_t> start;
12531253
SmallVector<int64_t> size;
12541254

1255-
if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
1256-
!tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
1255+
if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1256+
!tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
12571257
auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
12581258
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
12591259
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
@@ -1561,8 +1561,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
15611561
ShapeAdaptor inputShape(adaptor.getInput1().getType());
15621562
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
15631563
llvm::SmallVector<int64_t> newShapeValue;
1564-
if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
1565-
newShapeValue)) {
1564+
if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
1565+
newShapeValue)) {
15661566
auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
15671567
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
15681568
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
@@ -1611,7 +1611,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
16111611
RankedTensorType outputType = getType();
16121612

16131613
SmallVector<int64_t> shapeValues;
1614-
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
1614+
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
16151615
// skip following checks if shape is not constant
16161616
return mlir::success();
16171617
}
@@ -1916,11 +1916,12 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
19161916
return failure();
19171917

19181918
SmallVector<int64_t> scaleInt, offsetInt, borderInt;
1919-
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
1920-
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
1921-
offsetInt) ||
1922-
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
1923-
borderInt)) {
1919+
if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
1920+
scaleInt) ||
1921+
!tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
1922+
offsetInt) ||
1923+
!tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
1924+
borderInt)) {
19241925
return failure();
19251926
}
19261927

@@ -1960,9 +1961,9 @@ LogicalResult tosa::ResizeOp::verify() {
19601961
SmallVector<int64_t> scaleValues;
19611962
SmallVector<int64_t> offsetValues;
19621963
SmallVector<int64_t> borderValues;
1963-
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
1964-
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
1965-
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
1964+
if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
1965+
!tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
1966+
!tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
19661967
// Skip following checks if shape is not constant
19671968
return success();
19681969
}
@@ -3051,14 +3052,14 @@ OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
30513052

30523053
LogicalResult tosa::ConstShapeOp::verify() {
30533054
// check one dimensional rank
3054-
auto valuesRank = getValue().getType().getRank();
3055+
auto valuesRank = getValues().getType().getRank();
30553056
if (valuesRank != 1)
3056-
return emitOpError("expect elements in attribute value with rank 1");
3057-
// check that number of elements in value attr equal to rank of result shape
3058-
auto count = getValue().getNumElements();
3057+
return emitOpError("expect elements in attribute values with rank 1");
3058+
// check that number of elements in values attr equal to rank of result shape
3059+
auto count = getValues().getNumElements();
30593060
auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
30603061
if (!(count == rank || (count == 1 && rank == 0))) {
3061-
return emitOpError("expect number of elements in attribute value (")
3062+
return emitOpError("expect number of elements in attribute values (")
30623063
<< count << ") to be equal to the rank (" << rank
30633064
<< ") for the result shape type";
30643065
}

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
363363
return rewriter.notifyMatchFailure(op, "result type shape is not static");
364364

365365
auto reductionAxis = op.getAxis();
366-
const auto denseElementsAttr = constOp.getValue();
366+
const auto denseElementsAttr = constOp.getValues();
367367
const auto shapedOldElementsValues =
368368
cast<ShapedType>(denseElementsAttr.getType());
369369

mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,8 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
399399

400400
// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
401401
llvm::SmallVector<int64_t> newShape;
402-
if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
403-
newShape)) {
402+
if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
403+
newShape)) {
404404
// this mean shape is not constant
405405
return std::nullopt;
406406
}
@@ -418,7 +418,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
418418
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
419419
ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
420420
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
421-
auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
421+
auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
422422
if (!denseAttr)
423423
return std::nullopt;
424424
auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
342342
bool levelCheckResize(Operation *op) {
343343
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
344344
SmallVector<int64_t> scale;
345-
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
345+
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
346+
scale)) {
346347
return false;
347348
}
348349
const int64_t scaleYN = scale[0];
@@ -736,7 +737,7 @@ bool checkErrorIfResize(Operation *op) {
736737
}
737738

738739
SmallVector<int64_t> scale;
739-
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
740+
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
740741
return false;
741742
}
742743

@@ -761,8 +762,8 @@ bool checkErrorIfResize(Operation *op) {
761762

762763
SmallVector<int64_t> offset;
763764
SmallVector<int64_t> border;
764-
if (!tosa::getConstShapeValue(resize.getOffset().getDefiningOp(), offset) ||
765-
!tosa::getConstShapeValue(resize.getBorder().getDefiningOp(), border)) {
765+
if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
766+
!tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
766767
return false;
767768
}
768769

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
178178
}));
179179
}
180180

181-
bool mlir::tosa::getConstShapeValue(Operation *op,
182-
llvm::SmallVector<int64_t> &result_shape) {
181+
bool mlir::tosa::getConstShapeValues(Operation *op,
182+
llvm::SmallVector<int64_t> &result_shape) {
183183
if (!op) {
184184
return false;
185185
}
186186
if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
187-
Attribute constOpAttr = constOp->getAttr("value");
187+
Attribute constOpAttr = constOp->getAttr("values");
188188
DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
189189
for (int i = 0; i < elementsAttr.size(); i++) {
190190
int64_t val = elementsAttr.getValues<int64_t>()[i];

mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// CHECK-LABEL: func @const_test
55
func.func @const_test() -> (tensor<i32>) {
66
// CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
7-
%result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
7+
%result = "tosa.const"() {values = dense<3> : tensor<i32>} : () -> tensor<i32>
88

99
// CHECK: return [[C3]]
1010
return %result : tensor<i32>

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
2424
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
2525
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
2626
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
27-
%s = tosa.const_shape {value = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
27+
%s = tosa.const_shape {values = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
2828
%2 = tosa.reshape %0, %s : (tensor<*xf32>, !tosa.shape<2>) -> tensor<10x10xf32>
2929
return %2 : tensor<10x10xf32>
3030
}

0 commit comments

Comments
 (0)