
Commit 9688657

Tai78641 authored and FranklandJack committed
[mlir][tosa] Remove Quantization Attribute
Removed the TOSA quantization attribute used in various MLIR TOSA dialect operations in favour of builtin attributes. Updated lit tests, conversions, and transformations accordingly.

Signed-off-by: Tai Ly <[email protected]>
1 parent d9af03b commit 9688657
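In IR terms, the change moves zero points out of a dialect-specific struct attribute and into plain builtin i32 attributes. A hedged before/after sketch for tosa.negate (spellings inferred from the op-definition diff below, not copied from the commit's lit tests; zero-point values illustrative):

    // Before: zero points packed into a TOSA quantization attribute.
    %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 5, output_zp = 3>} : (tensor<1xi8>) -> tensor<1xi8>

    // After: zero points as builtin integer attributes.
    %0 = tosa.negate %arg0 {input1_zp = 5 : i32, output_zp = 3 : i32} : (tensor<1xi8>) -> tensor<1xi8>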

14 files changed (+161, −128 lines)


mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 5 deletions
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
     Tosa_Tensor2D:$input,
     TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$weight_zp
   );
 
   let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$a_zp,
+    OptionalAttr<I32Attr>:$b_zp
   );
 
   let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input1_zp,
+    OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Tosa_RankedTensor:$input1,
     Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
-    OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp
   );
 
   let results = (outs
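The same pattern applies to every op in this file: each quantization_info attribute is split into its constituent zero points. A hedged sketch of the resulting matmul assembly under the new definition (values and types illustrative):

    %0 = tosa.matmul %a, %b {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x2x3xi8>, tensor<1x3x4xi8>) -> tensor<1x2x4xi32>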

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 54 additions & 52 deletions
@@ -141,63 +141,65 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::NegateOp
-  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+  if (isa<tosa::NegateOp>(op)) {
+    if (isa<FloatType>(elementTy))
+      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
-    int64_t inZp = 0, outZp = 0;
+    if (isa<IntegerType>(elementTy)) {
+      auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1Zp();
+      auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZp();
 
-    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
-      inZp = quantizationInfo.value().getInputZp();
-      outZp = quantizationInfo.value().getOutputZp();
-    }
+      const int64_t inZp = inputZpAttr ? *inputZpAttr : 0;
+      const int64_t outZp = outputZpAttr ? *outputZpAttr : 0;
 
-    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    if (!inZp && !outZp) {
-      auto constant = rewriter.create<arith::ConstantOp>(
-          loc, IntegerAttr::get(elementTy, 0));
-      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
-                                            args[0]);
-    }
+      if (!inZp && !outZp) {
+        auto constant = rewriter.create<arith::ConstantOp>(
+            loc, IntegerAttr::get(elementTy, 0));
+        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
+                                              args[0]);
+      }
 
-    // Compute the maximum value that can occur in the intermediate buffer.
-    int64_t zpAdd = inZp + outZp;
-    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-                       std::abs(zpAdd) + 1;
-
-    // Convert that maximum value into the maximum bitwidth needed to represent
-    // it. We assume 48-bit numbers may be supported further in the pipeline.
-    int intermediateBitWidth = 64;
-    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-      intermediateBitWidth = 16;
-    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-      intermediateBitWidth = 32;
-    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-      intermediateBitWidth = 48;
-    }
+      // Compute the maximum value that can occur in the intermediate buffer.
+      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+      const int64_t zpAdd = inZp + outZp;
+      const int64_t maxValue =
+          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+          std::abs(zpAdd) + 1;
+
+      // Convert that maximum value into the maximum bitwidth needed to
+      // represent it. We assume 48-bit numbers may be supported further in
+      // the pipeline.
+      int intermediateBitWidth = 64;
+      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+        intermediateBitWidth = 16;
+      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+        intermediateBitWidth = 32;
+      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+        intermediateBitWidth = 48;
+      }
 
-    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-    Value zpAddValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
-    // The negation can be applied by doing:
-    //   outputValue = inZp + outZp - inputValue
-    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
-    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
-    // Clamp to the negation range.
-    Value min = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    Value max = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    auto clamp =
-        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
-    // Truncate to the final value.
-    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+      Value zpAddValue = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+      // The negation can be applied by doing:
+      //   outputValue = inZp + outZp - inputValue
+      auto ext =
+          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+      // Clamp to the negation range.
+      Value min = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      Value max = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
+
+      // Truncate to the final value.
+      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+    }
   }
 
   // tosa::BitwiseAndOp
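A worked instance of the integer path above, with illustrative zero points (not taken from the commit): for an i8 input with input1_zp = 5 and output_zp = 3,

    zpAdd    = 5 + 3 = 8
    maxValue = 127 + |8| + 1 = 136   // <= 32767, so intermediateBitWidth = 16
    result   = trunc_i8(clamp(8 - sext_i16(x), -128, 127))

so an input of -100 maps to clamp(8 + 100) = 108, while an input of -128 saturates at clamp(8 + 128) = 127.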

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 13 additions & 21 deletions
@@ -590,18 +590,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
             .create<linalg::FillOp>(loc, ValueRange{zero},
                                     ValueRange{emptyTensor})
             .result();
-    if (!op.getQuantizationInfo()) {
+    if (!op.getAZp() && !op.getBZp()) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
-    auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (!op.getQuantizationInfo()) {
+    if (!op.getInputZp() && !op.getWeightZp()) {
       Value matmul = rewriter
                          .create<linalg::MatmulOp>(
                              loc, TypeRange{op.getType()},
@@ -672,11 +669,9 @@
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
-    auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+    auto outputZp =
+        rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -958,10 +953,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
       // If we have quantization information we need to apply an offset
      // for the input zp value.
-      if (op.getQuantizationInfo()) {
-        auto quantizationInfo = *op.getQuantizationInfo();
-        auto inputZp = rewriter.create<arith::ConstantOp>(
-            loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+      if (op.getInputZp()) {
+        auto inputZp =
+            rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
         Value offset =
             rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
         poolVal =
@@ -1013,11 +1007,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
       // If we have quantization information we need to apply output
       // zeropoint.
-      if (op.getQuantizationInfo()) {
-        auto quantizationInfo = *op.getQuantizationInfo();
-        auto outputZp = rewriter.create<arith::ConstantOp>(
-            loc, b.getIntegerAttr(scaled.getType(),
-                                  quantizationInfo.getOutputZp()));
+      if (op.getOutputZp()) {
+        auto outputZp =
+            rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
         scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                      .getResult();
       }
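With zero-point attributes present, MatMulConverter still emits linalg.quantized_batch_matmul; only the source of the zp constants changes. A hedged sketch of the lowered form (shapes, values, and SSA names illustrative):

    %a_zp = arith.constant 1 : i32
    %b_zp = arith.constant 2 : i32
    %0 = linalg.quantized_batch_matmul
           ins(%a, %b, %a_zp, %b_zp : tensor<1x2x3xi8>, tensor<1x3x4xi8>, i32, i32)
           outs(%zero_tensor : tensor<1x2x4xi32>) -> tensor<1x2x4xi32>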

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 3 additions & 3 deletions
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     TypedAttr constantAttr;
     if (isa<FloatType>(elementTy)) {
       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+    } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
-      int64_t value = padOp.getQuantizationInfo()->getInputZp();
+    } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+      int64_t value = padOp.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
     if (constantAttr)
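The selected constant ends up as the yielded pad value of the generated tensor.pad. A hedged sketch for an i8 pad with input_zp = 42 (shapes and padding amounts illustrative):

    %c42 = arith.constant 42 : i8
    %padded = tensor.pad %arg0 low[1, 1] high[1, 1] {
    ^bb0(%i: index, %j: index):
      tensor.yield %c42 : i8
    } : tensor<2x2xi8> to tensor<4x4xi8>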

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 3 deletions
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     Attribute constantAttr;
     if (llvm::isa<FloatType>(elementTy)) {
       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
-      auto value = op.getQuantizationInfo()->getInputZp();
+    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+      int64_t value = op.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
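A hedged before/after sketch of what MaterializePadValue now does with the zero point (printed forms approximate):

    // Before: pad value only implied by input_zp.
    %0 = tosa.pad %arg0, %padding {input_zp = 42 : i32} : (tensor<2x2xi8>, !tosa.shape<4>) -> tensor<4x4xi8>

    // After: the zero point is materialized as an explicit pad_const operand.
    %cst = "tosa.const"() <{value = dense<42> : tensor<i8>}> : () -> tensor<i8>
    %0 = tosa.pad %arg0, %padding, %cst : (tensor<2x2xi8>, !tosa.shape<4>, tensor<i8>) -> tensor<4x4xi8>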

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 43 additions & 17 deletions
@@ -271,11 +271,11 @@ static LogicalResult verifyConvOp(T op) {
     }
   }
 
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
 
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
+  // Either both must be float or both non-float.
+  if (inputIsFloat != weightIsFloat) {
     op.emitOpError(
         "expect both input and weight to be float or not together, got ")
         << inputEType << " and " << weightEType;
@@ -527,7 +527,12 @@ static void buildTransConvOpWithQuantInfo(
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("weight_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getWeightZp())));
     result.addTypes(
         buildConvOpResultTypeInfo(builder, outputType, input, weight));
   } else {
@@ -563,7 +568,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("a_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getAZp())));
+    result.addAttribute("b_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getBZp())));
 
     auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
     assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +611,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -616,8 +630,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                       Value input) {
   result.addOperands(input);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    // note: negateOp has attributes input1_zp and output_zp
+    result.addAttribute("input1_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -629,8 +650,11 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Value paddings) {
   result.addOperands({input, paddings});
   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -643,8 +667,11 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                  Value padConst) {
   result.addOperands({input, paddings, padConst});
   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -898,9 +925,8 @@ LogicalResult FullyConnectedOp::verify() {
 
   // Quantized type must have constructed the quantizationattr, and unquantized
   // types should not have a quantizationattr.
-  if ((inputIsQuant && !getQuantizationInfo()) ||
-      (!inputIsQuant && getQuantizationInfo())) {
-    emitOpError("quantizationattr is required for quantized type, and not "
+  if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
+    emitOpError("input zero point is required for quantized type, and not "
                 "allowed for float type");
     return failure();
   }