Skip to content

Commit 6c9eef3

Browse files
lhutton1Tai78641
authored andcommitted
[TOSA] Switch zero point of negate to input variable type
This commit changes the zero point attribute to an input to align with the 1.0 spec. Change-Id: Ibc9e5959b36c182a9e0c5c23a2f9d42a572a1184 Signed-off-by: Tai Ly <[email protected]>
1 parent b8a66f5 commit 6c9eef3

File tree

17 files changed

+364
-91
lines changed

17 files changed

+364
-91
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,12 @@ profileComplianceMap = {
112112
{"tosa.logical_not",
113113
{{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
114114
{"tosa.negate",
115-
{{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
116-
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
115+
{{{Profile::pro_int},
116+
{{i8T, i8T, i8T, i8T},
117+
{i16T, i16T, i16T, i16T},
118+
{i32T, i32T, i32T, i32T}}},
119+
{{Profile::pro_fp},
120+
{{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
117121
{"tosa.reciprocal",
118122
{{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
119123
{"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -308,7 +312,7 @@ extensionComplianceMap = {
308312
{"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
309313
{"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
310314
{"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
311-
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
315+
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
312316
{"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
313317
{"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
314318
{"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
178178
input, kernel, stride, pad, acc_type);
179179
}]>;
180180

181-
// This builder is called on single-parameter unary operators that have a scale
181+
// This builder is called on single-parameter negate operators that have a scale
182182
// relationship between their input and output, expressed by the
183183
// UnaryOpQuantizationAttr.
184-
def Tosa_UnaryOpQuantInfoBuilder : OpBuilder<
184+
def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
185185
(ins "Type":$outputType, "Value":$input),
186186
[{
187-
buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
187+
buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
188188
}]>;
189189

190190
// These builders are called on the TOSA pad operator that needs to create its

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
13431343
//===----------------------------------------------------------------------===//
13441344
// Operator: negate
13451345
//===----------------------------------------------------------------------===//
1346-
def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
1346+
def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
1347+
TosaElementwiseOperator,
1348+
Pure]> {
13471349
let summary = "Elementwise negate op";
13481350

13491351
let description = [{
@@ -1352,8 +1354,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
13521354

13531355
let arguments = (ins
13541356
Tosa_Tensor:$input1,
1355-
OptionalAttr<I32Attr>:$input1_zp,
1356-
OptionalAttr<I32Attr>:$output_zp
1357+
Tosa_ScalarTensor:$input1_zp,
1358+
Tosa_ScalarTensor:$output_zp
13571359
);
13581360

13591361
let results = (outs
@@ -1365,9 +1367,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
13651367
Extension<[Tosa_EXT_BF16]>,
13661368
];
13671369

1368-
let builders = [Tosa_UnaryOpQuantInfoBuilder];
1370+
let builders = [Tosa_NegateOpQuantInfoBuilder];
1371+
1372+
let extraClassDeclaration = [{
1373+
FailureOr<int64_t> getInput1ZeroPoint();
1374+
FailureOr<int64_t> getOutputZeroPoint();
1375+
LogicalResult verifyInput1ZeroPoint(int64_t zp);
1376+
LogicalResult verifyOutputZeroPoint(int64_t zp);
1377+
}];
13691378

13701379
let hasFolder = 1;
1380+
let hasVerifier = 1;
1381+
1382+
let assemblyFormat =
1383+
"operands attr-dict `:` functional-type(operands, results)";
13711384
}
13721385

13731386
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(
193193

194194
// tosa::NegateOp
195195
if (isa<tosa::NegateOp>(op)) {
196-
if (isa<FloatType>(elementTy))
197-
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
196+
auto negate = cast<tosa::NegateOp>(op);
198197

199-
if (isa<IntegerType>(elementTy)) {
200-
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
201-
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
198+
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
199+
if (failed(maybeInZp)) {
200+
(void)rewriter.notifyMatchFailure(
201+
op, "input1 zero point cannot be statically determined");
202+
return nullptr;
203+
}
204+
205+
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
206+
if (failed(maybeOutZp)) {
207+
(void)rewriter.notifyMatchFailure(
208+
op, "output zero point cannot be statically determined");
209+
return nullptr;
210+
}
202211

203-
const int64_t inZp =
204-
inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
205-
const int64_t outZp =
206-
outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
212+
int64_t inZp = *maybeInZp;
213+
int64_t outZp = *maybeOutZp;
207214

215+
if (isa<FloatType>(elementTy))
216+
return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
217+
218+
if (isa<IntegerType>(elementTy)) {
208219
if (!inZp && !outZp) {
209220
auto constant = rewriter.create<arith::ConstantOp>(
210221
loc, IntegerAttr::get(elementTy, 0));

mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,45 @@ struct MatMulOpSharding
6060
}
6161
};
6262

63+
struct NegateOpSharding
64+
: public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
65+
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
66+
Value val = op->getOperand(0);
67+
auto type = dyn_cast<RankedTensorType>(val.getType());
68+
if (!type)
69+
return {};
70+
SmallVector<utils::IteratorType> types(type.getRank(),
71+
utils::IteratorType::parallel);
72+
return types;
73+
}
74+
75+
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
76+
MLIRContext *ctx = op->getContext();
77+
Value val = op->getOperand(0);
78+
auto type = dyn_cast<RankedTensorType>(val.getType());
79+
if (!type)
80+
return {};
81+
int64_t rank = type.getRank();
82+
SmallVector<AffineMap> maps = {
83+
AffineMap::getMultiDimIdentityMap(rank, ctx),
84+
AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
85+
AffineMap::getMultiDimIdentityMap(rank, ctx)};
86+
return maps;
87+
}
88+
89+
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
90+
ArrayRef<MeshSharding> operandShardings,
91+
ArrayRef<MeshSharding> resultShardings,
92+
IRMapping &spmdizationMap,
93+
SymbolTableCollection &symbolTable,
94+
OpBuilder &builder) const {
95+
spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
96+
resultShardings, spmdizationMap,
97+
symbolTable, builder);
98+
return success();
99+
}
100+
};
101+
63102
template <typename OpType>
64103
static void registerElemwiseOne(MLIRContext *ctx) {
65104
OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
@@ -82,9 +121,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
82121
BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
83122
LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
84123
MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
85-
LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
124+
LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
86125
GreaterOp, GreaterEqualOp>(ctx);
87126

88127
MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
128+
NegateOp::attachInterface<NegateOpSharding>(*ctx);
89129
});
90130
}

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,13 +1190,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
11901190
}
11911191

11921192
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1193-
auto input = getInput1();
11941193
// Element-wise negate(negate(x)) = x
1195-
if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
1196-
return op.getInput1();
1194+
// iff all zero points are constant 0
1195+
auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1196+
if (!definingOp) {
1197+
// defining op of input1 is not a negate, cannot fold
1198+
return {};
11971199
}
11981200

1199-
return {};
1201+
if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1202+
failed(maybeIZp) || *maybeIZp != 0) {
1203+
// input1 zero point is not constant 0, cannot fold
1204+
return {};
1205+
}
1206+
if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1207+
failed(maybeOZp) || *maybeOZp != 0) {
1208+
// output zero point is not constant 0, cannot fold
1209+
return {};
1210+
}
1211+
if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1212+
failed(maybeIZp) || *maybeIZp != 0) {
1213+
// definingOp's input1 zero point is not constant 0, cannot fold
1214+
return {};
1215+
}
1216+
if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1217+
failed(maybeOZp) || *maybeOZp != 0) {
1218+
// definingOp's output zero point is not constant 0, cannot fold
1219+
return {};
1220+
}
1221+
1222+
return definingOp.getInput1();
12001223
}
12011224

12021225
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -680,23 +680,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
680680
result.types.push_back(outputType);
681681
}
682682

683-
/// This builder is called on single-parameter unary operators that have scale
684-
/// relationship between their input and output, expressed by the
685-
/// UnaryOpQuantizationAttr.
686-
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
687-
OperationState &result, Type outputType,
688-
Value input) {
689-
result.addOperands(input);
683+
/// This builder is called on single-parameter negate operator that
684+
/// have scale relationship between their input and output, expressed
685+
/// by the UnaryOpQuantizationAttr.
686+
static void buildNegateOpWithQuantInfo(OpBuilder &builder,
687+
OperationState &result, Type outputType,
688+
Value input) {
689+
const Location loc{result.location};
690+
int64_t input1Zp{0};
691+
int64_t outputZp{0};
690692
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
691693
if (quantAttr) {
692-
// note: negateOp has attributes input1_zp and output_zp
693-
result.addAttribute("input1_zp",
694-
builder.getI32IntegerAttr(
695-
static_cast<int32_t>(quantAttr.getInputZp())));
696-
result.addAttribute("output_zp",
697-
builder.getI32IntegerAttr(
698-
static_cast<int32_t>(quantAttr.getOutputZp())));
694+
input1Zp = quantAttr.getInputZp();
695+
outputZp = quantAttr.getOutputZp();
696+
}
697+
const std::optional<Value> input1ZpOp =
698+
createZeroPointTensor(builder, loc, input.getType(), input1Zp);
699+
if (!input1ZpOp) {
700+
(void)emitError(
701+
loc, "Failed to create input1 zero point for quantized NEGATE op");
699702
}
703+
704+
const std::optional<Value> outputZpOp =
705+
createZeroPointTensor(builder, loc, input.getType(), outputZp);
706+
if (!outputZpOp) {
707+
(void)emitError(
708+
loc, "Failed to create output zero point for quantized NEGATE op");
709+
}
710+
711+
if (input1ZpOp && outputZpOp) {
712+
result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
713+
} else {
714+
// failed to create one or more zero points above: just add input as
715+
// operands. This will trigger error in building the op because of
716+
// missing zero points
717+
result.addOperands({input});
718+
}
719+
700720
result.types.push_back(outputType);
701721
}
702722

@@ -1560,6 +1580,9 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
15601580
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
15611581
ZERO_POINT_HELPER(AvgPool2dOp, Input)
15621582
ZERO_POINT_HELPER(AvgPool2dOp, Output)
1583+
ZERO_POINT_HELPER(NegateOp, Input1)
1584+
ZERO_POINT_HELPER(NegateOp, Output)
1585+
15631586
#undef ZERO_POINT_HELPER
15641587

15651588
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2039,7 +2062,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
20392062
NARY_SHAPE_INFER(tosa::LogicalXorOp)
20402063
NARY_SHAPE_INFER(tosa::MaximumOp)
20412064
NARY_SHAPE_INFER(tosa::MinimumOp)
2042-
NARY_SHAPE_INFER(tosa::NegateOp)
20432065
NARY_SHAPE_INFER(tosa::PowOp)
20442066
NARY_SHAPE_INFER(tosa::ReciprocalOp)
20452067
NARY_SHAPE_INFER(tosa::RescaleOp)
@@ -2053,6 +2075,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
20532075
NARY_SHAPE_INFER(tosa::SigmoidOp)
20542076
#undef PRED_SHAPE_INFER
20552077

2078+
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2079+
MLIRContext *context, ::std::optional<Location> location,
2080+
NegateOp::Adaptor adaptor,
2081+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2082+
ShapeAdaptor inputShape(adaptor.getInput1().getType());
2083+
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2084+
return success();
2085+
}
2086+
2087+
LogicalResult tosa::NegateOp::verify() {
2088+
// Verify same element type
2089+
const Type input1Type = getInput1().getType();
2090+
const Type outputType = getOutput().getType();
2091+
if (verifySameElementTypes(*this, input1Type, outputType).failed())
2092+
return failure();
2093+
2094+
// Verify same shape
2095+
const SmallVector<Type, 2> types = {input1Type, outputType};
2096+
if (failed(verifyCompatibleShapes(types)))
2097+
return emitOpError() << "requires the same shape for input1 and output";
2098+
2099+
const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2100+
const Type input1ZpEType =
2101+
getStorageElementTypeOrSelf(getInput1Zp().getType());
2102+
if (input1EType != input1ZpEType) {
2103+
return emitOpError("expect both input1 and its zero point are the same "
2104+
"element type, got ")
2105+
<< input1EType << " and " << input1ZpEType;
2106+
}
2107+
const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2108+
const Type outputZpEType =
2109+
getStorageElementTypeOrSelf(getOutputZp().getType());
2110+
if (outputEType != outputZpEType) {
2111+
return emitOpError("expect both output and its zero point are the same "
2112+
"element type, got ")
2113+
<< outputEType << " and " << outputZpEType;
2114+
}
2115+
2116+
FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2117+
if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2118+
return failure();
2119+
2120+
FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2121+
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2122+
return failure();
2123+
2124+
return success();
2125+
}
2126+
20562127
static LogicalResult poolingInferReturnTypes(
20572128
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
20582129
ArrayRef<int64_t> pad,

0 commit comments

Comments
 (0)