-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[TOSA] tosa.negate operator lowering update #107924
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
tosa.negate for integer types now uses the simpified branch if input_zp and output_zp values are also zero. Signed-off-by: Dmitriy Smirnov <[email protected]>
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tosa Author: Dmitriy Smirnov (d-smirnov) ChangesThis PR makes tosa.negate op for integer types to use the simplified calculation branch if input_zp and output_zp values are also zero. Full diff: https://github.com/llvm/llvm-project/pull/107924.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ba259d4b84fceb..93e284af051883 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -139,19 +139,22 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
- !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
- auto constant =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
- return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
- }
+ if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
+ int64_t inZp = 0, outZp = 0;
+
+ if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
+ auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
+ inZp = quantizationInfo.value().getInputZp();
+ outZp = quantizationInfo.value().getOutputZp();
+ }
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
- cast<tosa::NegateOp>(op).getQuantizationInfo()) {
- auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- int64_t inZp = quantizationInfo.value().getInputZp();
- int64_t outZp = quantizationInfo.value().getOutputZp();
+ if (!inZp && !outZp) {
+ auto constant = rewriter.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, 0));
+ return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
+ args[0]);
+ }
// Compute the maximum value that can occur in the intermediate buffer.
int64_t zpAdd = inZp + outZp;
@@ -402,17 +405,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (intTy.isUnsignedInteger()) {
minRepresentable = 0;
if (intTy.getIntOrFloatBitWidth() <= 63) {
- maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
- .getZExtValue();
+ maxRepresentable =
+ (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+ .getZExtValue();
}
- } else if(intTy.getIntOrFloatBitWidth() <= 64) {
+ } else if (intTy.getIntOrFloatBitWidth() <= 64) {
// Ensure that min & max fit into signed n-bit constants.
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue();
+ .getSExtValue();
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue();
+ .getSExtValue();
}
- // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+ // Ensure that the bounds are representable as n-bit signed/unsigned
+ // integers.
min = std::max(min, minRepresentable);
max = std::max(max, minRepresentable);
min = std::min(min, maxRepresentable);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..f9d37f9427d4f4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -857,16 +857,16 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: [[CNST:%.+]] = arith.constant 7
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
+ // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
@@ -878,6 +878,13 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
%2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
+ // CHECK: linalg.yield [[SUB]]
+ %3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+
return
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Have a couple of questions:
- The TOSA spec highlights that
input_zp
andoutput_zp
must be zero for non-int8 types. I presume if that is the case then some of the code can be simplified when it comes to integer handling? - Is there a verifier for the negate op that checks the above?
|
This PR makes tosa.negate op for integer types to use the simplified calculation branch if input_zp and output_zp values are also zero.