Skip to content

Commit 483c23f

Browse files
Tai78641lhutton1
andauthored
[mlir][tosa] Switch zero point of negate to input variable type (#129758)
This commit changes the zero point attribute to an input to align with the 1.0 spec. Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
1 parent 7016f2d commit 483c23f

File tree

17 files changed

+364
-91
lines changed

17 files changed

+364
-91
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,12 @@ profileComplianceMap = {
114114
{"tosa.logical_not",
115115
{{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
116116
{"tosa.negate",
117-
{{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
118-
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
117+
{{{Profile::pro_int},
118+
{{i8T, i8T, i8T, i8T},
119+
{i16T, i16T, i16T, i16T},
120+
{i32T, i32T, i32T, i32T}}},
121+
{{Profile::pro_fp},
122+
{{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
119123
{"tosa.reciprocal",
120124
{{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
121125
{"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -310,7 +314,7 @@ extensionComplianceMap = {
310314
{"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
311315
{"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
312316
{"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
313-
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
317+
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
314318
{"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
315319
{"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
316320
{"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
178178
input, kernel, stride, pad, acc_type);
179179
}]>;
180180

181-
// This builder is called on single-parameter unary operators that have a scale
181+
// This builder is called on single-parameter negate operators that have a scale
182182
// relationship between their input and output, expressed by the
183183
// UnaryOpQuantizationAttr.
184-
def Tosa_UnaryOpQuantInfoBuilder : OpBuilder<
184+
def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
185185
(ins "Type":$outputType, "Value":$input),
186186
[{
187-
buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
187+
buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
188188
}]>;
189189

190190
// These builders are called on the TOSA pad operator that needs to create its

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
13561356
//===----------------------------------------------------------------------===//
13571357
// Operator: negate
13581358
//===----------------------------------------------------------------------===//
1359-
def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
1359+
def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
1360+
TosaElementwiseOperator,
1361+
Pure]> {
13601362
let summary = "Elementwise negate op";
13611363

13621364
let description = [{
@@ -1365,8 +1367,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
13651367

13661368
let arguments = (ins
13671369
Tosa_Tensor:$input1,
1368-
OptionalAttr<I32Attr>:$input1_zp,
1369-
OptionalAttr<I32Attr>:$output_zp
1370+
Tosa_ScalarIntOrFloatTensor:$input1_zp,
1371+
Tosa_ScalarIntOrFloatTensor:$output_zp
13701372
);
13711373

13721374
let results = (outs
@@ -1378,9 +1380,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
13781380
Extension<[Tosa_EXT_BF16]>,
13791381
];
13801382

1381-
let builders = [Tosa_UnaryOpQuantInfoBuilder];
1383+
let builders = [Tosa_NegateOpQuantInfoBuilder];
1384+
1385+
let extraClassDeclaration = [{
1386+
FailureOr<int64_t> getInput1ZeroPoint();
1387+
FailureOr<int64_t> getOutputZeroPoint();
1388+
LogicalResult verifyInput1ZeroPoint(int64_t zp);
1389+
LogicalResult verifyOutputZeroPoint(int64_t zp);
1390+
}];
13821391

13831392
let hasFolder = 1;
1393+
let hasVerifier = 1;
1394+
1395+
let assemblyFormat =
1396+
"operands attr-dict `:` functional-type(operands, results)";
13841397
}
13851398

13861399
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(
193193

194194
// tosa::NegateOp
195195
if (isa<tosa::NegateOp>(op)) {
196-
if (isa<FloatType>(elementTy))
197-
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
196+
auto negate = cast<tosa::NegateOp>(op);
198197

199-
if (isa<IntegerType>(elementTy)) {
200-
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
201-
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
198+
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
199+
if (failed(maybeInZp)) {
200+
(void)rewriter.notifyMatchFailure(
201+
op, "input1 zero point cannot be statically determined");
202+
return nullptr;
203+
}
204+
205+
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
206+
if (failed(maybeOutZp)) {
207+
(void)rewriter.notifyMatchFailure(
208+
op, "output zero point cannot be statically determined");
209+
return nullptr;
210+
}
202211

203-
const int64_t inZp =
204-
inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
205-
const int64_t outZp =
206-
outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
212+
int64_t inZp = *maybeInZp;
213+
int64_t outZp = *maybeOutZp;
207214

215+
if (isa<FloatType>(elementTy))
216+
return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
217+
218+
if (isa<IntegerType>(elementTy)) {
208219
if (!inZp && !outZp) {
209220
auto constant = rewriter.create<arith::ConstantOp>(
210221
loc, IntegerAttr::get(elementTy, 0));

mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,45 @@ struct MatMulOpSharding
6262
}
6363
};
6464

65+
struct NegateOpSharding
66+
: public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
67+
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
68+
Value val = op->getOperand(0);
69+
auto type = dyn_cast<RankedTensorType>(val.getType());
70+
if (!type)
71+
return {};
72+
SmallVector<utils::IteratorType> types(type.getRank(),
73+
utils::IteratorType::parallel);
74+
return types;
75+
}
76+
77+
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
78+
MLIRContext *ctx = op->getContext();
79+
Value val = op->getOperand(0);
80+
auto type = dyn_cast<RankedTensorType>(val.getType());
81+
if (!type)
82+
return {};
83+
int64_t rank = type.getRank();
84+
SmallVector<AffineMap> maps = {
85+
AffineMap::getMultiDimIdentityMap(rank, ctx),
86+
AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
87+
AffineMap::getMultiDimIdentityMap(rank, ctx)};
88+
return maps;
89+
}
90+
91+
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
92+
ArrayRef<MeshSharding> operandShardings,
93+
ArrayRef<MeshSharding> resultShardings,
94+
IRMapping &spmdizationMap,
95+
SymbolTableCollection &symbolTable,
96+
OpBuilder &builder) const {
97+
spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
98+
resultShardings, spmdizationMap,
99+
symbolTable, builder);
100+
return success();
101+
}
102+
};
103+
65104
template <typename OpType>
66105
static void registerElemwiseOne(MLIRContext *ctx) {
67106
OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
@@ -84,9 +123,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
84123
BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
85124
LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
86125
MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
87-
LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
126+
LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
88127
GreaterOp, GreaterEqualOp>(ctx);
89128

90129
MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
130+
NegateOp::attachInterface<NegateOpSharding>(*ctx);
91131
});
92132
}

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,13 +1143,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
11431143
}
11441144

11451145
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1146-
auto input = getInput1();
11471146
// Element-wise negate(negate(x)) = x
1148-
if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
1149-
return op.getInput1();
1147+
// iff all zero points are constant 0
1148+
auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1149+
if (!definingOp) {
1150+
// defining op of input1 is not a negate, cannot fold
1151+
return {};
11501152
}
11511153

1152-
return {};
1154+
if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1155+
failed(maybeIZp) || *maybeIZp != 0) {
1156+
// input1 zero point is not constant 0, cannot fold
1157+
return {};
1158+
}
1159+
if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1160+
failed(maybeOZp) || *maybeOZp != 0) {
1161+
// output zero point is not constant 0, cannot fold
1162+
return {};
1163+
}
1164+
if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1165+
failed(maybeIZp) || *maybeIZp != 0) {
1166+
// definingOp's input1 zero point is not constant 0, cannot fold
1167+
return {};
1168+
}
1169+
if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1170+
failed(maybeOZp) || *maybeOZp != 0) {
1171+
// definingOp's output zero point is not constant 0, cannot fold
1172+
return {};
1173+
}
1174+
1175+
return definingOp.getInput1();
11531176
}
11541177

11551178
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -697,23 +697,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
697697
result.types.push_back(outputType);
698698
}
699699

700-
/// This builder is called on single-parameter unary operators that have scale
701-
/// relationship between their input and output, expressed by the
702-
/// UnaryOpQuantizationAttr.
703-
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
704-
OperationState &result, Type outputType,
705-
Value input) {
706-
result.addOperands(input);
700+
/// This builder is called on single-parameter negate operator
701+
/// to construct input and output zero points based on their
702+
/// types.
703+
static void buildNegateOpWithQuantInfo(OpBuilder &builder,
704+
OperationState &result, Type outputType,
705+
Value input) {
706+
const Location loc{result.location};
707+
int64_t input1Zp{0};
708+
int64_t outputZp{0};
707709
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
708710
if (quantAttr) {
709-
// note: negateOp has attributes input1_zp and output_zp
710-
result.addAttribute("input1_zp",
711-
builder.getI32IntegerAttr(
712-
static_cast<int32_t>(quantAttr.getInputZp())));
713-
result.addAttribute("output_zp",
714-
builder.getI32IntegerAttr(
715-
static_cast<int32_t>(quantAttr.getOutputZp())));
711+
input1Zp = quantAttr.getInputZp();
712+
outputZp = quantAttr.getOutputZp();
713+
}
714+
const std::optional<Value> input1ZpOp =
715+
createZeroPointTensor(builder, loc, input.getType(), input1Zp);
716+
if (!input1ZpOp) {
717+
(void)emitError(
718+
loc, "Failed to create input1 zero point for quantized NEGATE op");
719+
}
720+
721+
const std::optional<Value> outputZpOp =
722+
createZeroPointTensor(builder, loc, input.getType(), outputZp);
723+
if (!outputZpOp) {
724+
(void)emitError(
725+
loc, "Failed to create output zero point for quantized NEGATE op");
716726
}
727+
728+
if (input1ZpOp && outputZpOp) {
729+
result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
730+
} else {
731+
// failed to create one or more zero points above: just add input as
732+
// operands. This will trigger error in building the op because of
733+
// missing zero points
734+
result.addOperands({input});
735+
}
736+
717737
result.types.push_back(outputType);
718738
}
719739

@@ -1729,6 +1749,9 @@ ZERO_POINT_HELPER(AvgPool2dOp, Input)
17291749
ZERO_POINT_HELPER(AvgPool2dOp, Output)
17301750
ZERO_POINT_HELPER(MatMulOp, A)
17311751
ZERO_POINT_HELPER(MatMulOp, B)
1752+
ZERO_POINT_HELPER(NegateOp, Input1)
1753+
ZERO_POINT_HELPER(NegateOp, Output)
1754+
17321755
#undef ZERO_POINT_HELPER
17331756

17341757
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2231,7 +2254,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
22312254
NARY_SHAPE_INFER(tosa::LogicalXorOp)
22322255
NARY_SHAPE_INFER(tosa::MaximumOp)
22332256
NARY_SHAPE_INFER(tosa::MinimumOp)
2234-
NARY_SHAPE_INFER(tosa::NegateOp)
22352257
NARY_SHAPE_INFER(tosa::PowOp)
22362258
NARY_SHAPE_INFER(tosa::ReciprocalOp)
22372259
NARY_SHAPE_INFER(tosa::ReverseOp)
@@ -2244,6 +2266,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
22442266
NARY_SHAPE_INFER(tosa::SigmoidOp)
22452267
#undef PRED_SHAPE_INFER
22462268

2269+
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2270+
MLIRContext *context, ::std::optional<Location> location,
2271+
NegateOp::Adaptor adaptor,
2272+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2273+
ShapeAdaptor inputShape(adaptor.getInput1().getType());
2274+
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2275+
return success();
2276+
}
2277+
2278+
LogicalResult tosa::NegateOp::verify() {
2279+
// Verify same element type
2280+
const Type input1Type = getInput1().getType();
2281+
const Type outputType = getOutput().getType();
2282+
if (verifySameElementTypes(*this, input1Type, outputType).failed())
2283+
return failure();
2284+
2285+
// Verify same shape
2286+
const SmallVector<Type, 2> types = {input1Type, outputType};
2287+
if (failed(verifyCompatibleShapes(types)))
2288+
return emitOpError() << "requires the same shape for input1 and output";
2289+
2290+
const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2291+
const Type input1ZpEType =
2292+
getStorageElementTypeOrSelf(getInput1Zp().getType());
2293+
if (input1EType != input1ZpEType) {
2294+
return emitOpError("expect both input1 and its zero point are the same "
2295+
"element type, got ")
2296+
<< input1EType << " and " << input1ZpEType;
2297+
}
2298+
const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2299+
const Type outputZpEType =
2300+
getStorageElementTypeOrSelf(getOutputZp().getType());
2301+
if (outputEType != outputZpEType) {
2302+
return emitOpError("expect both output and its zero point are the same "
2303+
"element type, got ")
2304+
<< outputEType << " and " << outputZpEType;
2305+
}
2306+
2307+
FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2308+
if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2309+
return failure();
2310+
2311+
FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2312+
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2313+
return failure();
2314+
2315+
return success();
2316+
}
2317+
22472318
static LogicalResult poolingInferReturnTypes(
22482319
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
22492320
ArrayRef<int64_t> pad,

0 commit comments

Comments
 (0)