@@ -139,19 +139,22 @@ static Value createLinalgBodyCalculationForElementwiseOp(
139
139
if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
140
140
return rewriter.create <arith::NegFOp>(loc, resultTypes, args);
141
141
142
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
143
- !cast<tosa::NegateOp>(op).getQuantizationInfo ()) {
144
- auto constant =
145
- rewriter.create <arith::ConstantOp>(loc, IntegerAttr::get (elementTy, 0 ));
146
- return rewriter.create <arith::SubIOp>(loc, resultTypes, constant, args[0 ]);
147
- }
142
+ if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
143
+ int64_t inZp = 0 , outZp = 0 ;
144
+
145
+ if (cast<tosa::NegateOp>(op).getQuantizationInfo ()) {
146
+ auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo ();
147
+ inZp = quantizationInfo.value ().getInputZp ();
148
+ outZp = quantizationInfo.value ().getOutputZp ();
149
+ }
148
150
149
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
150
- cast<tosa::NegateOp>(op).getQuantizationInfo ()) {
151
- auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo ();
152
151
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth ();
153
- int64_t inZp = quantizationInfo.value ().getInputZp ();
154
- int64_t outZp = quantizationInfo.value ().getOutputZp ();
152
+ if (!inZp && !outZp) {
153
+ auto constant = rewriter.create <arith::ConstantOp>(
154
+ loc, IntegerAttr::get (elementTy, 0 ));
155
+ return rewriter.create <arith::SubIOp>(loc, resultTypes, constant,
156
+ args[0 ]);
157
+ }
155
158
156
159
// Compute the maximum value that can occur in the intermediate buffer.
157
160
int64_t zpAdd = inZp + outZp;
@@ -402,17 +405,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
402
405
if (intTy.isUnsignedInteger ()) {
403
406
minRepresentable = 0 ;
404
407
if (intTy.getIntOrFloatBitWidth () <= 63 ) {
405
- maxRepresentable = (int64_t )APInt::getMaxValue (intTy.getIntOrFloatBitWidth ())
406
- .getZExtValue ();
408
+ maxRepresentable =
409
+ (int64_t )APInt::getMaxValue (intTy.getIntOrFloatBitWidth ())
410
+ .getZExtValue ();
407
411
}
408
- } else if (intTy.getIntOrFloatBitWidth () <= 64 ) {
412
+ } else if (intTy.getIntOrFloatBitWidth () <= 64 ) {
409
413
// Ensure that min & max fit into signed n-bit constants.
410
414
minRepresentable = APInt::getSignedMinValue (intTy.getIntOrFloatBitWidth ())
411
- .getSExtValue ();
415
+ .getSExtValue ();
412
416
maxRepresentable = APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
413
- .getSExtValue ();
417
+ .getSExtValue ();
414
418
}
415
- // Ensure that the bounds are representable as n-bit signed/unsigned integers.
419
+ // Ensure that the bounds are representable as n-bit signed/unsigned
420
+ // integers.
416
421
min = std::max (min, minRepresentable);
417
422
max = std::max (max, minRepresentable);
418
423
min = std::min (min, maxRepresentable);
0 commit comments