
[mlir][tosa] Remove Quantization Attribute #125479


Merged

merged 1 commit on Feb 5, 2025

Conversation

FranklandJack
Contributor

Remove the TOSA quantization attributes used by various MLIR TOSA dialect operations in favour of builtin integer attributes.

Update lit tests, conversions, and transformations accordingly.

Rename operands and regions as follows to align with the TOSA-v1.0 specification:

tosa.cond_if:

  • cond -> condition
  • output -> output_list
  • then_branch -> then_graph
  • else_branch -> else_graph

tosa.while_loop:

  • inputs -> input_list
  • output -> output_list
  • cond -> cond_graph
  • body -> body_graph
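As an illustration of the attribute change (syntax is approximate and assumed from the diff, not taken from the patch's own tests), the composite quantization attribute on an op such as tosa.negate is replaced by individual builtin i32 zero-point attributes:

```mlir
// Before (illustrative): one composite quantization attribute
%0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = 127>}
       : (tensor<1x8xi8>) -> tensor<1x8xi8>

// After (illustrative): separate builtin integer attributes
%1 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 127 : i32}
       : (tensor<1x8xi8>) -> tensor<1x8xi8>
```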

@llvmbot
Member

llvmbot commented Feb 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Jack Frankland (FranklandJack)


Patch is 39.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125479.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+17-13)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+52-48)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+11-20)
  • (modified) mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp (+6-6)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+56-27)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+2-3)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+2-2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+19-9)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+5-5)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 819547855d1015..fef0f2d98d95c6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
     Tosa_Tensor2D:$input,
     TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$weight_zp
   );
 
   let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$a_zp,
+    OptionalAttr<I32Attr>:$b_zp
   );
 
   let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
 
   let arguments = (ins
       Tosa_Tensor:$input1,
-      OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+      OptionalAttr<I32Attr>:$input1_zp,
+      OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Tosa_RankedTensor:$input1,
     Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
-    OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp
   );
 
   let results = (outs
@@ -2071,17 +2075,17 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   }];
 
   let arguments = (ins
-    Tosa_I1Tensor:$cond,
+    Tosa_I1Tensor:$condition,
     Variadic<Tosa_Tensor>:$inputs
   );
 
   let results = (outs
-    Variadic<Tosa_Tensor>:$output
+    Variadic<Tosa_Tensor>:$output_list
   );
 
   let regions = (region
-    SizedRegion<1>:$then_branch,
-    SizedRegion<1>:$else_branch
+    SizedRegion<1>:$then_graph,
+    SizedRegion<1>:$else_graph
   );
 
   let hasCustomAssemblyFormat = 1;
@@ -2108,16 +2112,16 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$inputs
+    Variadic<Tosa_Tensor>:$input_list
   );
 
   let results = (outs
-    Variadic<Tosa_Tensor>:$output
+    Variadic<Tosa_Tensor>:$output_list
   );
 
   let regions = (region
-    SizedRegion<1>:$cond,
-    SizedRegion<1>:$body
+    SizedRegion<1>:$cond_graph,
+    SizedRegion<1>:$body_graph
   );
 
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..449baad0edeafe 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -141,63 +141,67 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::NegateOp
-  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+  if (isa<tosa::NegateOp>(op)) {
+    if (isa<FloatType>(elementTy))
+      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
-    int64_t inZp = 0, outZp = 0;
+    auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
+    auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+    int32_t inputZpVal = inputZpAttr ? inputZpAttr.getInt() : 0;
+    int32_t outputZpVal = outputZpAttr ? outputZpAttr.getInt() : 0;
 
-    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
-      inZp = quantizationInfo.value().getInputZp();
-      outZp = quantizationInfo.value().getOutputZp();
-    }
-
-    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    if (!inZp && !outZp) {
+    if (isa<IntegerType>(elementTy) && inputZpVal == 0 && outputZpVal == 0) {
       auto constant = rewriter.create<arith::ConstantOp>(
           loc, IntegerAttr::get(elementTy, 0));
       return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                             args[0]);
     }
 
-    // Compute the maximum value that can occur in the intermediate buffer.
-    int64_t zpAdd = inZp + outZp;
-    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-                       std::abs(zpAdd) + 1;
-
-    // Convert that maximum value into the maximum bitwidth needed to represent
-    // it. We assume 48-bit numbers may be supported further in the pipeline.
-    int intermediateBitWidth = 64;
-    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-      intermediateBitWidth = 16;
-    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-      intermediateBitWidth = 32;
-    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-      intermediateBitWidth = 48;
-    }
+    if (isa<IntegerType>(elementTy) && (inputZpVal != 0 || outputZpVal != 0)) {
+      int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+      int64_t inZp = inputZpVal;
+      int64_t outZp = outputZpVal;
+
+      // Compute the maximum value that can occur in the intermediate buffer.
+      int64_t zpAdd = inZp + outZp;
+      int64_t maxValue =
+          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+          std::abs(zpAdd) + 1;
+
+      // Convert that maximum value into the maximum bitwidth needed to
+      // represent it. We assume 48-bit numbers may be supported further in the
+      // pipeline.
+      int intermediateBitWidth = 64;
+      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+        intermediateBitWidth = 16;
+      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+        intermediateBitWidth = 32;
+      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+        intermediateBitWidth = 48;
+      }
 
-    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-    Value zpAddValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
-    // The negation can be applied by doing:
-    //  outputValue = inZp + outZp - inputValue
-    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
-    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
-    // Clamp to the negation range.
-    Value min = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    Value max = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    auto clamp =
-        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
-    // Truncate to the final value.
-    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+      Value zpAddValue = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+      // The negation can be applied by doing:
+      //  outputValue = inZp + outZp - inputValue
+      auto ext =
+          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+      // Clamp to the negation range.
+      Value min = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      Value max = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
+
+      // Truncate to the final value.
+      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+    }
   }
 
   // tosa::BitwiseAndOp
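The integer path in the hunk above computes outputValue = inZp + outZp - inputValue in a widened intermediate type, then clamps to the signed range of the input element type before truncating. A minimal scalar sketch of that arithmetic (the helper name is mine, not from the patch):

```python
def negate_zero_point(x: int, in_zp: int, out_zp: int, bit_width: int = 8) -> int:
    """Scalar model of the quantized tosa.negate lowering (sketch only)."""
    # The lowering widens first, so in_zp + out_zp - x cannot overflow.
    result = in_zp + out_zp - x
    # Clamp back to the signed range of the original element type.
    lo = -(1 << (bit_width - 1))
    hi = (1 << (bit_width - 1)) - 1
    return max(lo, min(hi, result))
```

With both zero points at 0 this reduces to plain negation (with -128 clamping to 127 in i8), which is why the lowering emits a simple arith.subi in that case.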
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9..1e02301f7c23d5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -590,18 +590,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
                            .create<linalg::FillOp>(loc, ValueRange{zero},
                                                    ValueRange{emptyTensor})
                            .result();
-    if (!op.getQuantizationInfo()) {
+    if (!op.getAZp() && !op.getBZp()) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
-    auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (!op.getQuantizationInfo()) {
+    if (!op.getInputZp() && !op.getWeightZp()) {
       Value matmul = rewriter
                          .create<linalg::MatmulOp>(
                              loc, TypeRange{op.getType()},
@@ -672,11 +669,8 @@ class FullyConnectedConverter
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
-    auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+    auto outputZp = rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -958,10 +952,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply an offset
             // for the input zp value.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
+            if (op.getInputZp()) {
               auto inputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+                  loc, op.getInputZpAttr());
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
               poolVal =
@@ -1013,11 +1006,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply output
             // zeropoint.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
-              auto outputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(scaled.getType(),
-                                        quantizationInfo.getOutputZp()));
+            if (op.getOutputZp()) {
+              auto outputZp =
+                  rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9139bf191fdf11..80c58bdc0550cc 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
   LogicalResult matchAndRewrite(tosa::IfOp op,
                                 PatternRewriter &rewriter) const final {
     auto condition =
-        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
+        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
     auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
                                             condition, true);
 
-    inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
+    inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputs(),
                  rewriter);
-    inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
+    inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputs(),
                  rewriter);
 
     rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
   LogicalResult matchAndRewrite(tosa::WhileOp op,
                                 PatternRewriter &rewriter) const final {
     auto newWhile = rewriter.create<scf::WhileOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs());
+        op.getLoc(), op.getResultTypes(), op.getInputList());
     rewriter.createBlock(&newWhile.getBefore());
     rewriter.createBlock(&newWhile.getAfter());
 
-    inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
-    inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
+    inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
+    inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
 
     rewriter.replaceOp(op, newWhile.getResults());
 
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..2a9b4d111bdfa2 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
       TypedAttr constantAttr;
       if (isa<FloatType>(elementTy)) {
         constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+      } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
-        int64_t value = padOp.getQuantizationInfo()->getInputZp();
+      } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+        int64_t value = padOp.getInputZpAttr().getInt();
         constantAttr = rewriter.getIntegerAttr(elementTy, value);
       }
       if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..8e22c879753a33 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     Attribute constantAttr;
     if (llvm::isa<FloatType>(elementTy)) {
       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
-      auto value = op.getQuantizationInfo()->getInputZp();
+    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+      int64_t value = op.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..9bde6a85935255 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -124,7 +124,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
 //===----------------------------------------------------------------------===//
 
 /// Returns the while loop body.
-SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
+  return {&getBodyGraph()};
+}
 
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
@@ -271,11 +273,11 @@ static LogicalResult verifyConvOp(T op) {
     }
   }
 
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
 
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
+  // Either both must be float or both non-float.
+  if (inputIsFloat != weightIsFloat) {
     op.emitOpError(
         "expect both input and weight to be float or not together, got ")
         << inputEType << " and " << weightEType;
@@ -527,7 +529,12 @@ static void buildTransConvOpWithQuantInfo(
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("weight_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getWeightZp())));
     result.addTypes(
         buildConvOpResultTypeInfo(builder, outputType, input, weight));
   } else {
@@ -563,7 +570,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("a_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getAZp())));
+    result.addAttribute("b_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getBZp())));
 
     auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
     assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +613,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -616,8 +632,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                       Value input) {
   result.addOperands(input);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    // note: negateOp has attributes input1_zp and output_zp
+    result.addAttribute("input1_zp",
+                     ...
[truncated]

@llvmbot
Member

llvmbot commented Feb 3, 2025

@llvm/pr-subscribers-mlir-linalg

                            .result();
-    if (!op.getQuantizationInfo()) {
+    if (!op.getAZp() && !op.getBZp()) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
-    auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (!op.getQuantizationInfo()) {
+    if (!op.getInputZp() && !op.getWeightZp()) {
       Value matmul = rewriter
                          .create<linalg::MatmulOp>(
                              loc, TypeRange{op.getType()},
@@ -672,11 +669,8 @@ class FullyConnectedConverter
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
-    auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+    auto outputZp = rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -958,10 +952,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply an offset
             // for the input zp value.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
+            if (op.getInputZp()) {
               auto inputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+                  loc, op.getInputZpAttr());
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
               poolVal =
@@ -1013,11 +1006,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply output
             // zeropoint.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
-              auto outputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(scaled.getType(),
-                                        quantizationInfo.getOutputZp()));
+            if (op.getOutputZp()) {
+              auto outputZp =
+                  rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9139bf191fdf11..80c58bdc0550cc 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
   LogicalResult matchAndRewrite(tosa::IfOp op,
                                 PatternRewriter &rewriter) const final {
     auto condition =
-        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
+        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
     auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
                                             condition, true);
 
-    inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
+    inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputs(),
                  rewriter);
-    inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
+    inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputs(),
                  rewriter);
 
     rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
   LogicalResult matchAndRewrite(tosa::WhileOp op,
                                 PatternRewriter &rewriter) const final {
     auto newWhile = rewriter.create<scf::WhileOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs());
+        op.getLoc(), op.getResultTypes(), op.getInputList());
     rewriter.createBlock(&newWhile.getBefore());
     rewriter.createBlock(&newWhile.getAfter());
 
-    inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
-    inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
+    inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
+    inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
 
     rewriter.replaceOp(op, newWhile.getResults());
 
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..2a9b4d111bdfa2 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
       TypedAttr constantAttr;
       if (isa<FloatType>(elementTy)) {
         constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+      } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
-        int64_t value = padOp.getQuantizationInfo()->getInputZp();
+      } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+        int64_t value = padOp.getInputZpAttr().getInt();
         constantAttr = rewriter.getIntegerAttr(elementTy, value);
       }
       if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..8e22c879753a33 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     Attribute constantAttr;
     if (llvm::isa<FloatType>(elementTy)) {
       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
-      auto value = op.getQuantizationInfo()->getInputZp();
+    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+      int64_t value = op.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..9bde6a85935255 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -124,7 +124,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
 //===----------------------------------------------------------------------===//
 
 /// Returns the while loop body.
-SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
+  return {&getBodyGraph()};
+}
 
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
@@ -271,11 +273,11 @@ static LogicalResult verifyConvOp(T op) {
     }
   }
 
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
 
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
+  // Either both must be float or both non-float.
+  if (inputIsFloat != weightIsFloat) {
     op.emitOpError(
         "expect both input and weight to be float or not together, got ")
         << inputEType << " and " << weightEType;
@@ -527,7 +529,12 @@ static void buildTransConvOpWithQuantInfo(
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("weight_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getWeightZp())));
     result.addTypes(
         buildConvOpResultTypeInfo(builder, outputType, input, weight));
   } else {
@@ -563,7 +570,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("a_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getAZp())));
+    result.addAttribute("b_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getBZp())));
 
     auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
     assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +613,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -616,8 +632,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                       Value input) {
   result.addOperands(input);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    // note: negateOp has attributes input1_zp and output_zp
+    result.addAttribute("input1_zp",
+                     ...
[truncated]
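The first hunk above restructures the `tosa.negate` lowering's choice of an intermediate accumulator type. That selection logic can be sketched as a standalone function (a simplified sketch: `pickIntermediateBitWidth` and its MLIR-free signature are illustrative, not part of the patch — the real code works on `APInt` and `rewriter.getIntegerType`):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Given the input element bit-width and the two zero points, pick the
// narrowest signed bit-width (16/32/48, else 64) that can represent
// inZp + outZp - input without overflow, mirroring the lowering above.
int pickIntermediateBitWidth(int inputBitWidth, int64_t inZp, int64_t outZp) {
  const int64_t zpAdd = inZp + outZp;
  const int64_t inputMax = (int64_t{1} << (inputBitWidth - 1)) - 1;
  // Maximum magnitude that can occur in the intermediate buffer.
  const int64_t maxValue = inputMax + std::llabs(zpAdd) + 1;

  const int64_t signedMax16 = (int64_t{1} << 15) - 1;
  const int64_t signedMax32 = (int64_t{1} << 31) - 1;
  const int64_t signedMax48 = (int64_t{1} << 47) - 1;

  if (maxValue <= signedMax16)
    return 16;
  if (maxValue <= signedMax32)
    return 32;
  if (maxValue <= signedMax48)
    return 48;
  // The lowering assumes 48-bit+ math is legalized further down the pipeline.
  return 64;
}
```

Note that an i8 input with zero zero-points still needs i16 (127 + 1 = 128 does not fit in i8), which is why the lowering always widens before the subtraction and truncates back at the end.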

github-actions bot commented Feb 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@GeorgeARM GeorgeARM left a comment


Can we split into two patches?
Also, didn't see any tests and verifier changes when having fp and zero-points. Is this already handled?

@FranklandJack
Contributor Author

Also, didn't see any tests and verifier changes when having fp and zero-points. Is this already handled?

Not sure I'm completely following here, but I think until these are moved to inputs they will always be integers since they are i32 attributes?

@GeorgeARM
Contributor

Also, didn't see any tests and verifier changes when having fp and zero-points. Is this already handled?

Not sure I'm completely following here, but I think until these are moved to inputs they will always be integers since they are i32 attributes?

Can the operators you changed have fp inputs/outputs? If so, do we expect a zero point in such configurations? Do we check in the verifier and raise an error when a zero point is provided with fp?

@FranklandJack
Contributor Author

Also, didn't see any tests and verifier changes when having fp and zero-points. Is this already handled?

Not sure I'm completely following here, but I think until these are moved to inputs they will always be integers since they are i32 attributes?

Can the operators you changed have fp inputs/outputs? If so, do we expect a zero point in such configurations? Do we check in the verifier and raise an error when a zero point is provided with fp?

Ah okay I see. So as far as I can tell no, we don't, apart from maybe FC. We should add these, but since they weren't there before this commit and the quantization attribute was still optional (and given that these will move to inputs shortly anyway) I think this should probably be done in a separate commit.
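For context, the surface-level effect of the change under discussion on e.g. `tosa.negate` looks roughly like the following (an illustrative sketch: the zero-point values and tensor types are invented, and the old `#tosa.unary_quant` syntax is recalled from the pre-patch dialect, not shown in this diff):

```mlir
// Before: dedicated TOSA quantization attribute.
%0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32, output_zp = 17>} : (tensor<4xi8>) -> tensor<4xi8>

// After: plain builtin i32 attributes (negate uses input1_zp/output_zp,
// per the note in the builder diff above).
%0 = tosa.negate %arg0 {input1_zp = 32 : i32, output_zp = 17 : i32} : (tensor<4xi8>) -> tensor<4xi8>
```

With builtin attributes there is no structural coupling between the two zero points, which is why the lowerings above test each getter (`getAZp()`, `getInputZp()`, …) individually instead of testing a single optional `quantization_info`.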

Removed the TOSA quantization attribute used in various MLIR TOSA
dialect operations in favour of using builtin attributes.

Update any lit tests, conversions and transformations appropriately.

Signed-off-by: Tai Ly <[email protected]>
@FranklandJack FranklandJack merged commit f0b8ff1 into llvm:main Feb 5, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Removed the TOSA quantization attribute used in various MLIR TOSA
dialect operations in favour of using builtin attributes.

Update any lit tests, conversions and transformations appropriately.

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Tai Ly <[email protected]>