Skip to content

Commit 2778d9d

Browse files
authored
[TOSA] tosa.negate operator lowering update (llvm#107924)
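This PR makes the lowering of the tosa.negate op for integer types use the simplified calculation branch (a plain `0 - x` subtraction) not only when no quantization info is present, but also when quantization info is present and both the input_zp and output_zp values are zero. Signed-off-by: Dmitriy Smirnov <[email protected]>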
1 parent a794ee4 commit 2778d9d

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

+22-17
Original file line numberDiff line numberDiff line change
@@ -139,19 +139,22 @@ static Value createLinalgBodyCalculationForElementwiseOp(
139139
if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
140140
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
141141

142-
if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
143-
!cast<tosa::NegateOp>(op).getQuantizationInfo()) {
144-
auto constant =
145-
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
146-
return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
147-
}
142+
if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
143+
int64_t inZp = 0, outZp = 0;
144+
145+
if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
146+
auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
147+
inZp = quantizationInfo.value().getInputZp();
148+
outZp = quantizationInfo.value().getOutputZp();
149+
}
148150

149-
if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
150-
cast<tosa::NegateOp>(op).getQuantizationInfo()) {
151-
auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
152151
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
153-
int64_t inZp = quantizationInfo.value().getInputZp();
154-
int64_t outZp = quantizationInfo.value().getOutputZp();
152+
if (!inZp && !outZp) {
153+
auto constant = rewriter.create<arith::ConstantOp>(
154+
loc, IntegerAttr::get(elementTy, 0));
155+
return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
156+
args[0]);
157+
}
155158

156159
// Compute the maximum value that can occur in the intermediate buffer.
157160
int64_t zpAdd = inZp + outZp;
@@ -402,17 +405,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
402405
if (intTy.isUnsignedInteger()) {
403406
minRepresentable = 0;
404407
if (intTy.getIntOrFloatBitWidth() <= 63) {
405-
maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
406-
.getZExtValue();
408+
maxRepresentable =
409+
(int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
410+
.getZExtValue();
407411
}
408-
} else if(intTy.getIntOrFloatBitWidth() <= 64) {
412+
} else if (intTy.getIntOrFloatBitWidth() <= 64) {
409413
// Ensure that min & max fit into signed n-bit constants.
410414
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
411-
.getSExtValue();
415+
.getSExtValue();
412416
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
413-
.getSExtValue();
417+
.getSExtValue();
414418
}
415-
// Ensure that the bounds are representable as n-bit signed/unsigned integers.
419+
// Ensure that the bounds are representable as n-bit signed/unsigned
420+
// integers.
416421
min = std::max(min, minRepresentable);
417422
max = std::max(max, minRepresentable);
418423
min = std::min(min, maxRepresentable);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

+10-3
Original file line numberDiff line numberDiff line change
@@ -857,16 +857,16 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
857857
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
858858
// CHECK: linalg.generic
859859
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
860-
// CHECK: [[ZERO:%.+]] = arith.constant 0
860+
// CHECK: [[CNST:%.+]] = arith.constant 7
861861
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
862-
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
862+
// CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
863863
// CHECK: [[MIN:%.+]] = arith.constant -128
864864
// CHECK: [[MAX:%.+]] = arith.constant 127
865865
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
866866
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
867867
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
868868
// CHECK: linalg.yield [[TRUNC]]
869-
%0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
869+
%0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
870870

871871
// CHECK: linalg.generic
872872
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
@@ -878,6 +878,13 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
878878
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
879879
%2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
880880

881+
// CHECK: linalg.generic
882+
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
883+
// CHECK: [[ZERO:%.+]] = arith.constant 0
884+
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
885+
// CHECK: linalg.yield [[SUB]]
886+
%3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
887+
881888
return
882889
}
883890

0 commit comments

Comments
 (0)