Skip to content

[mlir][tosa] Switch zero point of negate to input variable type #129758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,12 @@ profileComplianceMap = {
{"tosa.logical_not",
{{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
{"tosa.negate",
{{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{{{Profile::pro_int},
{{i8T, i8T, i8T, i8T},
{i16T, i16T, i16T, i16T},
{i32T, i32T, i32T, i32T}}},
{{Profile::pro_fp},
{{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.reciprocal",
{{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
Expand Down Expand Up @@ -310,7 +314,7 @@ extensionComplianceMap = {
{"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
{"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
input, kernel, stride, pad, acc_type);
}]>;

// This builder is called on single-parameter negate operators that have a scale
// relationship between their input and output, expressed by the
// UnaryOpQuantizationAttr.
def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
  (ins "Type":$outputType, "Value":$input),
  [{
    buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
  }]>;

// These builders are called on the TOSA pad operator that needs to create its
Expand Down
21 changes: 17 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1356,7 +1356,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
//===----------------------------------------------------------------------===//
// Operator: negate
//===----------------------------------------------------------------------===//
def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
TosaElementwiseOperator,
Pure]> {
let summary = "Elementwise negate op";

let description = [{
Expand All @@ -1365,8 +1367,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {

let arguments = (ins
Tosa_Tensor:$input1,
OptionalAttr<I32Attr>:$input1_zp,
OptionalAttr<I32Attr>:$output_zp
Tosa_ScalarIntOrFloatTensor:$input1_zp,
Tosa_ScalarIntOrFloatTensor:$output_zp
);

let results = (outs
Expand All @@ -1378,9 +1380,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
Extension<[Tosa_EXT_BF16]>,
];

let builders = [Tosa_UnaryOpQuantInfoBuilder];
let builders = [Tosa_NegateOpQuantInfoBuilder];

let extraClassDeclaration = [{
FailureOr<int64_t> getInput1ZeroPoint();
FailureOr<int64_t> getOutputZeroPoint();
LogicalResult verifyInput1ZeroPoint(int64_t zp);
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];

let hasFolder = 1;
let hasVerifier = 1;

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
}

//===----------------------------------------------------------------------===//
Expand Down
29 changes: 20 additions & 9 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(

// tosa::NegateOp
if (isa<tosa::NegateOp>(op)) {
if (isa<FloatType>(elementTy))
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
auto negate = cast<tosa::NegateOp>(op);

if (isa<IntegerType>(elementTy)) {
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
if (failed(maybeInZp)) {
(void)rewriter.notifyMatchFailure(
op, "input1 zero point cannot be statically determined");
return nullptr;
}

FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
if (failed(maybeOutZp)) {
(void)rewriter.notifyMatchFailure(
op, "output zero point cannot be statically determined");
return nullptr;
}

const int64_t inZp =
inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
const int64_t outZp =
outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
int64_t inZp = *maybeInZp;
int64_t outZp = *maybeOutZp;

if (isa<FloatType>(elementTy))
return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);

if (isa<IntegerType>(elementTy)) {
if (!inZp && !outZp) {
auto constant = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, 0));
Expand Down
42 changes: 41 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,45 @@ struct MatMulOpSharding
}
};

/// Sharding model for tosa.negate: the element-wise tensor operand/result are
/// fully parallel; the two scalar zero-point operands are replicated.
struct NegateOpSharding
    : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
  // Every dimension of the (ranked) first operand is independently shardable.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto tensorType =
        dyn_cast<RankedTensorType>(op->getOperand(0).getType());
    if (!tensorType)
      return {};
    return SmallVector<utils::IteratorType>(tensorType.getRank(),
                                            utils::IteratorType::parallel);
  }

  // Identity maps for input1 and the result; the zero-point operands get
  // zero-dimensional (scalar) maps.
  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    auto tensorType =
        dyn_cast<RankedTensorType>(op->getOperand(0).getType());
    if (!tensorType)
      return {};
    const int64_t rank = tensorType.getRank();
    const AffineMap identityMap = AffineMap::getMultiDimIdentityMap(rank, ctx);
    const AffineMap scalarMap = AffineMap::get(0, 0, {}, ctx);
    return {identityMap, scalarMap, scalarMap, identityMap};
  }

  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    // Negate is trivially shardable: delegate to the generic helper.
    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
                                       resultShardings, spmdizationMap,
                                       symbolTable, builder);
    return success();
  }
};

template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
Expand All @@ -84,9 +123,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
GreaterOp, GreaterEqualOp>(ctx);

MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
NegateOp::attachInterface<NegateOpSharding>(*ctx);
});
}
31 changes: 27 additions & 4 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,13 +1143,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  // Element-wise negate(negate(x)) = x
  // iff all zero points are constant 0 — a non-zero (or non-constant) zero
  // point on either op means the two negations are not exact inverses.
  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
  if (!definingOp) {
    // defining op of input1 is not a negate, cannot fold
    return {};
  }

  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // input1 zero point is not constant 0, cannot fold
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // output zero point is not constant 0, cannot fold
    return {};
  }
  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // definingOp's input1 zero point is not constant 0, cannot fold
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // definingOp's output zero point is not constant 0, cannot fold
    return {};
  }

  // Both negations are plain sign flips: fold to the inner op's input.
  return definingOp.getInput1();
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
Expand Down
101 changes: 86 additions & 15 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,23 +697,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have scale
/// relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
Value input) {
result.addOperands(input);
/// This builder is called on single-parameter negate operator
/// to construct input and output zero points based on their
/// types.
static void buildNegateOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
Value input) {
const Location loc{result.location};
int64_t input1Zp{0};
int64_t outputZp{0};
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
if (quantAttr) {
// note: negateOp has attributes input1_zp and output_zp
result.addAttribute("input1_zp",
builder.getI32IntegerAttr(
static_cast<int32_t>(quantAttr.getInputZp())));
result.addAttribute("output_zp",
builder.getI32IntegerAttr(
static_cast<int32_t>(quantAttr.getOutputZp())));
input1Zp = quantAttr.getInputZp();
outputZp = quantAttr.getOutputZp();
}
const std::optional<Value> input1ZpOp =
createZeroPointTensor(builder, loc, input.getType(), input1Zp);
if (!input1ZpOp) {
(void)emitError(
loc, "Failed to create input1 zero point for quantized NEGATE op");
}

const std::optional<Value> outputZpOp =
createZeroPointTensor(builder, loc, input.getType(), outputZp);
if (!outputZpOp) {
(void)emitError(
loc, "Failed to create output zero point for quantized NEGATE op");
}

if (input1ZpOp && outputZpOp) {
result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
} else {
// failed to create one or more zero points above: just add input as
// operands. This will trigger error in building the op because of
// missing zero points
result.addOperands({input});
}

result.types.push_back(outputType);
}

Expand Down Expand Up @@ -1728,6 +1748,9 @@ ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
ZERO_POINT_HELPER(MatMulOp, A)
ZERO_POINT_HELPER(MatMulOp, B)
ZERO_POINT_HELPER(NegateOp, Input1)
ZERO_POINT_HELPER(NegateOp, Output)

#undef ZERO_POINT_HELPER

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
Expand Down Expand Up @@ -2230,7 +2253,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
Expand All @@ -2243,6 +2265,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER

LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Negate is element-wise: the result shape is exactly input1's shape.
  const ShapeAdaptor operandShape(adaptor.getInput1().getType());
  inferredReturnShapes.emplace_back(operandShape);
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

  // input1 and output must share an element type.
  if (failed(verifySameElementTypes(*this, input1Type, outputType)))
    return failure();

  // input1 and output must have compatible shapes.
  if (failed(verifyCompatibleShapes(SmallVector<Type, 2>{input1Type,
                                                         outputType})))
    return emitOpError() << "requires the same shape for input1 and output";

  // Each zero-point operand must match the storage element type of the
  // tensor it applies to.
  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  // When a zero point is statically known, enforce its value constraints;
  // non-constant zero points are checked elsewhere (at lowering time).
  const FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && failed(verifyInput1ZeroPoint(*maybeIZp)))
    return failure();

  const FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && failed(verifyOutputZeroPoint(*maybeOZp)))
    return failure();

  return success();
}

static LogicalResult poolingInferReturnTypes(
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
ArrayRef<int64_t> pad,
Expand Down
Loading