Skip to content

[TOSA] tosa.negate operator lowering update #107924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2024
Merged

Conversation

d-smirnov
Copy link
Contributor

This PR makes tosa.negate op for integer types to use the simplified calculation branch if input_zp and output_zp values are also zero.

tosa.negate for integer types now uses the simpified branch
if input_zp and output_zp values are also zero.

Signed-off-by: Dmitriy Smirnov <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Dmitriy Smirnov (d-smirnov)

Changes

This PR makes tosa.negate op for integer types to use the simplified calculation branch if input_zp and output_zp values are also zero.


Full diff: https://github.com/llvm/llvm-project/pull/107924.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-17)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-3)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ba259d4b84fceb..93e284af051883 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -139,19 +139,22 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
-      !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-    auto constant =
-        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
-    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
-  }
+  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
+    int64_t inZp = 0, outZp = 0;
+
+    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
+      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
+      inZp = quantizationInfo.value().getInputZp();
+      outZp = quantizationInfo.value().getOutputZp();
+    }
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
-      cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-    auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
     int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    int64_t inZp = quantizationInfo.value().getInputZp();
-    int64_t outZp = quantizationInfo.value().getOutputZp();
+    if (!inZp && !outZp) {
+      auto constant = rewriter.create<arith::ConstantOp>(
+          loc, IntegerAttr::get(elementTy, 0));
+      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
+                                            args[0]);
+    }
 
     // Compute the maximum value that can occur in the intermediate buffer.
     int64_t zpAdd = inZp + outZp;
@@ -402,17 +405,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     if (intTy.isUnsignedInteger()) {
       minRepresentable = 0;
       if (intTy.getIntOrFloatBitWidth() <= 63) {
-        maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
-                          .getZExtValue();
+        maxRepresentable =
+            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+                .getZExtValue();
       }
-    } else if(intTy.getIntOrFloatBitWidth() <= 64) {
+    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
       // Ensure that min & max fit into signed n-bit constants.
       minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue();
+                             .getSExtValue();
       maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue();
+                             .getSExtValue();
     }
-    // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+    // Ensure that the bounds are representable as n-bit signed/unsigned
+    // integers.
     min = std::max(min, minRepresentable);
     max = std::max(max, minRepresentable);
     min = std::min(min, maxRepresentable);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..f9d37f9427d4f4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -857,16 +857,16 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
+  // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
   // CHECK: [[MIN:%.+]] = arith.constant -128
   // CHECK: [[MAX:%.+]] = arith.constant 127
   // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
   // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
   // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+  %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
@@ -878,6 +878,13 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
   %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
 
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
+  // CHECK: linalg.yield [[SUB]]
+  %3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+
   return
 }
 

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Have a couple of questions:

  • The TOSA spec highlights that input_zp and output_zp must be zero for non-int8 types. I presume if that is the case then some of the code can be simplified when it comes to integer handling?
  • Is there a verifier for the negate op that checks the above?

@d-smirnov
Copy link
Contributor Author

d-smirnov commented Sep 10, 2024

  • This is for all integer types (including i8) which have input and output ZPs equal to zero, to avoid unnecessary usage of accumulator of extended integer type.
  • There is no such check as far as I can tell from the source code.

@GeorgeARM GeorgeARM merged commit 2778d9d into llvm:main Sep 10, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants