Skip to content

Commit dbd40f4

Browse files
authored
Use single rounding as default for TOSA lowering
Differential Revision: D61240443
Pull Request resolved: #4591
1 parent ba3448c commit dbd40f4

File tree

2 files changed: 14 additions (+14) and 33 deletions (-33)

2 files changed

+14
-33
lines changed

backends/arm/operators/op_addmm.py

+11-26
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
register_node_visitor,
1313
)
1414
from executorch.backends.arm.tosa_mapping import TosaArg
15-
from executorch.backends.arm.tosa_quant_utils import (
16-
compute_multiplier_and_shift,
17-
get_quant_node_args,
18-
)
15+
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
1916

2017
from executorch.backends.arm.tosa_utils import build_reshape
2118
from executorch.exir.dialects._ops import ops as exir_ops
@@ -128,32 +125,20 @@ def define_node(
128125
weight_scale = get_quant_node_args(weight_node_q_node).scale
129126

130127
output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
131-
(
132-
multiplier_output,
133-
shift_output,
134-
) = compute_multiplier_and_shift(output_rescale_scale)
135-
136-
attr_rescale_output = ts.TosaSerializerAttribute()
137-
attr_rescale_output.RescaleAttribute(
138-
input_zp=0,
139-
output_zp=consumer_node_node_zp,
140-
multiplier=[multiplier_output],
141-
shift=[shift_output],
142-
scale32=True,
143-
double_round=True,
144-
per_channel=False,
145-
input_unsigned=False,
146-
output_unsigned=False,
147-
)
148128

149129
reshaped_res = tosa_graph.addIntermediate(result_shape, ts.DType.INT32)
150130
build_reshape(tosa_graph, conv2d_res.name, result_shape, reshaped_res.name)
151131

152-
tosa_graph.addOperator(
153-
TosaOp.Op().RESCALE,
154-
[reshaped_res.name],
155-
[output.name],
156-
attr_rescale_output,
132+
build_rescale(
133+
tosa_fb=tosa_graph,
134+
scale=output_rescale_scale,
135+
input_node=reshaped_res,
136+
output_name=output.name,
137+
output_type=ts.DType.INT8,
138+
output_shape=reshaped_res.shape,
139+
input_zp=0,
140+
output_zp=consumer_node_node_zp,
141+
is_double_round=False,
157142
)
158143

159144
else:

backends/arm/tosa_quant_utils.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def build_rescale(
171171
output_shape,
172172
input_zp,
173173
output_zp,
174-
is_double_round,
174+
is_double_round=False,
175175
):
176176
scale_width = 32 if is_scale32(output_type) else 16
177177
multiplier, shift = compute_multiplier_and_shift(scale, scale_width)
@@ -197,7 +197,7 @@ def build_rescale(
197197

198198

199199
def build_rescale_to_int32(
200-
tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=True
200+
tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
201201
) -> TosaSerializerTensor:
202202
multiplier, shift = compute_multiplier_and_shift(rescale_scale)
203203
attr_rescale = ts.TosaSerializerAttribute()
@@ -230,7 +230,7 @@ def build_rescale_from_int32(
230230
output_zp,
231231
rescale_scale,
232232
is_scale32=True,
233-
is_double_round=True,
233+
is_double_round=False,
234234
) -> TosaSerializerTensor:
235235
multiplier, shift = compute_multiplier_and_shift(rescale_scale)
236236
attr_rescale_output = ts.TosaSerializerAttribute()
@@ -329,9 +329,6 @@ def build_rescale_conv_output(
329329
output_scale,
330330
output_zp,
331331
):
332-
# Only use double round if we are doing 32 bit scaling
333-
double_round = is_scale32(output_type)
334-
335332
# TODO add check to verify if this is a Per-channel quantization.
336333
post_conv2d_scale = (input_scale.number * weight_scale.number) / output_scale.number
337334

@@ -345,6 +342,5 @@ def build_rescale_conv_output(
345342
op.shape,
346343
0,
347344
output_zp.number,
348-
double_round,
349345
)
350346
return

0 commit comments

Comments
 (0)