Skip to content

Commit babe874

Browse files
Tai78641FranklandJack
authored andcommitted
[mlir][tosa] Remove Quantization Attribute
Removed the TOSA quantization attribute used in various MLIR TOSA dialect operations in favour of using builtin attributes. Update any lit tests, conversions and transformations appropriately. Rename operands as follows to align with the TOSA-v1.0 specification: * `cond` -> `condition` * `then_branch` -> `then_graph` * `else_branch` -> `else_graph` * `inputs` -> `input_list` * `output` -> `output_list` * `cond` -> `cond_graph` * `body` -> `body_graph` Signed-off-by: Tai Ly <[email protected]>
1 parent 50d5d06 commit babe874

File tree

15 files changed

+184
-147
lines changed

15 files changed

+184
-147
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
7878
Tosa_IntArrayAttr2:$stride,
7979
Tosa_IntArrayAttr4:$pad,
8080
TypeAttrOf<Tosa_AccType>:$acc_type,
81-
OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
81+
OptionalAttr<I32Attr>:$input_zp,
82+
OptionalAttr<I32Attr>:$output_zp
8283
);
8384

8485
let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
237238
Tosa_Tensor2D:$input,
238239
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
239240
Tosa_Tensor1D:$bias,
240-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
241+
OptionalAttr<I32Attr>:$input_zp,
242+
OptionalAttr<I32Attr>:$weight_zp
241243
);
242244

243245
let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
263265
let arguments = (ins
264266
Tosa_Tensor3D:$a,
265267
Tosa_Tensor3D:$b,
266-
OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
268+
OptionalAttr<I32Attr>:$a_zp,
269+
OptionalAttr<I32Attr>:$b_zp
267270
);
268271

269272
let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
11141117

11151118
let arguments = (ins
11161119
Tosa_Tensor:$input1,
1117-
OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
1120+
OptionalAttr<I32Attr>:$input1_zp,
1121+
OptionalAttr<I32Attr>:$output_zp
11181122
);
11191123

11201124
let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
15891593
Tosa_RankedTensor:$input1,
15901594
Tosa_Shape:$padding,
15911595
Optional<Tosa_ScalarTensor>:$pad_const,
1592-
OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
1596+
OptionalAttr<I32Attr>:$input_zp
15931597
);
15941598

15951599
let results = (outs
@@ -2071,17 +2075,17 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
20712075
}];
20722076

20732077
let arguments = (ins
2074-
Tosa_I1Tensor:$cond,
2078+
Tosa_I1Tensor:$condition,
20752079
Variadic<Tosa_Tensor>:$inputs
20762080
);
20772081

20782082
let results = (outs
2079-
Variadic<Tosa_Tensor>:$output
2083+
Variadic<Tosa_Tensor>:$output_list
20802084
);
20812085

20822086
let regions = (region
2083-
SizedRegion<1>:$then_branch,
2084-
SizedRegion<1>:$else_branch
2087+
SizedRegion<1>:$then_graph,
2088+
SizedRegion<1>:$else_graph
20852089
);
20862090

20872091
let hasCustomAssemblyFormat = 1;
@@ -2108,16 +2112,16 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
21082112
}];
21092113

21102114
let arguments = (ins
2111-
Variadic<Tosa_Tensor>:$inputs
2115+
Variadic<Tosa_Tensor>:$input_list
21122116
);
21132117

21142118
let results = (outs
2115-
Variadic<Tosa_Tensor>:$output
2119+
Variadic<Tosa_Tensor>:$output_list
21162120
);
21172121

21182122
let regions = (region
2119-
SizedRegion<1>:$cond,
2120-
SizedRegion<1>:$body
2123+
SizedRegion<1>:$cond_graph,
2124+
SizedRegion<1>:$body_graph
21212125
);
21222126

21232127
let hasCustomAssemblyFormat = 1;

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -141,63 +141,67 @@ static Value createLinalgBodyCalculationForElementwiseOp(
141141
}
142142

143143
// tosa::NegateOp
144-
if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
145-
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
144+
if (isa<tosa::NegateOp>(op)) {
145+
if (isa<FloatType>(elementTy))
146+
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
146147

147-
if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
148-
int64_t inZp = 0, outZp = 0;
148+
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
149+
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
150+
int32_t inputZpVal = inputZpAttr ? inputZpAttr.getInt() : 0;
151+
int32_t outputZpVal = outputZpAttr ? outputZpAttr.getInt() : 0;
149152

150-
if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
151-
auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
152-
inZp = quantizationInfo.value().getInputZp();
153-
outZp = quantizationInfo.value().getOutputZp();
154-
}
155-
156-
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
157-
if (!inZp && !outZp) {
153+
if (isa<IntegerType>(elementTy) && inputZpVal == 0 && outputZpVal == 0) {
158154
auto constant = rewriter.create<arith::ConstantOp>(
159155
loc, IntegerAttr::get(elementTy, 0));
160156
return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
161157
args[0]);
162158
}
163159

164-
// Compute the maximum value that can occur in the intermediate buffer.
165-
int64_t zpAdd = inZp + outZp;
166-
int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
167-
std::abs(zpAdd) + 1;
168-
169-
// Convert that maximum value into the maximum bitwidth needed to represent
170-
// it. We assume 48-bit numbers may be supported further in the pipeline.
171-
int intermediateBitWidth = 64;
172-
if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
173-
intermediateBitWidth = 16;
174-
} else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
175-
intermediateBitWidth = 32;
176-
} else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
177-
intermediateBitWidth = 48;
178-
}
160+
if (isa<IntegerType>(elementTy) && (inputZpVal != 0 || outputZpVal != 0)) {
161+
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
162+
int64_t inZp = inputZpVal;
163+
int64_t outZp = outputZpVal;
164+
165+
// Compute the maximum value that can occur in the intermediate buffer.
166+
int64_t zpAdd = inZp + outZp;
167+
int64_t maxValue =
168+
APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
169+
std::abs(zpAdd) + 1;
170+
171+
// Convert that maximum value into the maximum bitwidth needed to
172+
// represent it. We assume 48-bit numbers may be supported further in the
173+
// pipeline.
174+
int intermediateBitWidth = 64;
175+
if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
176+
intermediateBitWidth = 16;
177+
} else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
178+
intermediateBitWidth = 32;
179+
} else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
180+
intermediateBitWidth = 48;
181+
}
179182

180-
Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
181-
Value zpAddValue = rewriter.create<arith::ConstantOp>(
182-
loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
183-
184-
// The negation can be applied by doing:
185-
// outputValue = inZp + outZp - inputValue
186-
auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
187-
auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
188-
189-
// Clamp to the negation range.
190-
Value min = rewriter.create<arith::ConstantIntOp>(
191-
loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
192-
intermediateType);
193-
Value max = rewriter.create<arith::ConstantIntOp>(
194-
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
195-
intermediateType);
196-
auto clamp =
197-
clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
198-
199-
// Truncate to the final value.
200-
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
183+
Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
184+
Value zpAddValue = rewriter.create<arith::ConstantOp>(
185+
loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
186+
187+
// The negation can be applied by doing:
188+
// outputValue = inZp + outZp - inputValue
189+
auto ext =
190+
rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
191+
auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
192+
193+
// Clamp to the negation range.
194+
Value min = rewriter.create<arith::ConstantIntOp>(
195+
loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
196+
intermediateType);
197+
Value max = rewriter.create<arith::ConstantIntOp>(
198+
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
199+
intermediateType);
200+
auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
201+
202+
// Truncate to the final value.
203+
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
204+
}
201205
}
202206

203207
// tosa::BitwiseAndOp

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -590,18 +590,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
590590
.create<linalg::FillOp>(loc, ValueRange{zero},
591591
ValueRange{emptyTensor})
592592
.result();
593-
if (!op.getQuantizationInfo()) {
593+
if (!op.getAZp() && !op.getBZp()) {
594594
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
595595
op, TypeRange{op.getType()},
596596
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
597597
return success();
598598
}
599599

600-
auto quantizationInfo = *op.getQuantizationInfo();
601-
auto aZp = rewriter.create<arith::ConstantOp>(
602-
loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
603-
auto bZp = rewriter.create<arith::ConstantOp>(
604-
loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
600+
auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
601+
auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
605602
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
606603
op, TypeRange{op.getType()},
607604
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
661658
Value broadcastBias =
662659
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
663660

664-
if (!op.getQuantizationInfo()) {
661+
if (!op.getInputZp() && !op.getWeightZp()) {
665662
Value matmul = rewriter
666663
.create<linalg::MatmulOp>(
667664
loc, TypeRange{op.getType()},
@@ -672,11 +669,8 @@ class FullyConnectedConverter
672669
return success();
673670
}
674671

675-
auto quantizationInfo = *op.getQuantizationInfo();
676-
auto inputZp = rewriter.create<arith::ConstantOp>(
677-
loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
678-
auto outputZp = rewriter.create<arith::ConstantOp>(
679-
loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
672+
auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
673+
auto outputZp = rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
680674
Value matmul =
681675
rewriter
682676
.create<linalg::QuantizedMatmulOp>(
@@ -958,10 +952,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
958952

959953
// If we have quantization information we need to apply an offset
960954
// for the input zp value.
961-
if (op.getQuantizationInfo()) {
962-
auto quantizationInfo = *op.getQuantizationInfo();
955+
if (op.getInputZp()) {
963956
auto inputZp = rewriter.create<arith::ConstantOp>(
964-
loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
957+
loc, op.getInputZpAttr());
965958
Value offset =
966959
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
967960
poolVal =
@@ -1013,11 +1006,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
10131006

10141007
// If we have quantization information we need to apply output
10151008
// zeropoint.
1016-
if (op.getQuantizationInfo()) {
1017-
auto quantizationInfo = *op.getQuantizationInfo();
1018-
auto outputZp = rewriter.create<arith::ConstantOp>(
1019-
loc, b.getIntegerAttr(scaled.getType(),
1020-
quantizationInfo.getOutputZp()));
1009+
if (op.getOutputZp()) {
1010+
auto outputZp =
1011+
rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
10211012
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
10221013
.getResult();
10231014
}

mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
6868
LogicalResult matchAndRewrite(tosa::IfOp op,
6969
PatternRewriter &rewriter) const final {
7070
auto condition =
71-
rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
71+
rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
7272
auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
7373
condition, true);
7474

75-
inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
75+
inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputs(),
7676
rewriter);
77-
inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
77+
inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputs(),
7878
rewriter);
7979

8080
rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
158158
LogicalResult matchAndRewrite(tosa::WhileOp op,
159159
PatternRewriter &rewriter) const final {
160160
auto newWhile = rewriter.create<scf::WhileOp>(
161-
op.getLoc(), op.getResultTypes(), op.getInputs());
161+
op.getLoc(), op.getResultTypes(), op.getInputList());
162162
rewriter.createBlock(&newWhile.getBefore());
163163
rewriter.createBlock(&newWhile.getAfter());
164164

165-
inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
166-
inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
165+
inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
166+
inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
167167

168168
rewriter.replaceOp(op, newWhile.getResults());
169169

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
358358
TypedAttr constantAttr;
359359
if (isa<FloatType>(elementTy)) {
360360
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
361-
} else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
361+
} else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
362362
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
363-
} else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
364-
int64_t value = padOp.getQuantizationInfo()->getInputZp();
363+
} else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
364+
int64_t value = padOp.getInputZpAttr().getInt();
365365
constantAttr = rewriter.getIntegerAttr(elementTy, value);
366366
}
367367
if (constantAttr)

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
207207
Attribute constantAttr;
208208
if (llvm::isa<FloatType>(elementTy)) {
209209
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
210-
} else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
210+
} else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
211211
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
212-
} else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
213-
auto value = op.getQuantizationInfo()->getInputZp();
212+
} else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
213+
int64_t value = op.getInputZpAttr().getInt();
214214
constantAttr = rewriter.getIntegerAttr(elementTy, value);
215215
}
216216

0 commit comments

Comments
 (0)