Skip to content

Commit fa2e1f2

Browse files
Arm backend: Rename const-tensors for TOSA 1.0
Rename const tensors such input_zp to use node.name instead of input_node.name. This avoids clashes as conv and build_rescale used different names. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: If3b636e97bf623cc9dbbe681632a6bc9f4265083
1 parent 4909db1 commit fa2e1f2

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

backends/arm/operators/op_conv2d.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,17 +277,29 @@ def define_node(
277277
input_qparams = get_input_qparams(node)
278278
input_zp = input_qparams[0].zp
279279

280-
tosa_graph.addConst([1], output.dtype, [input_zp], name=f"{node.name}_input_zp")
281-
tosa_graph.addConst([1], output.dtype, [0], name=f"{node.name}_weight_zp")
280+
# The output type is int32 when input type is int8.
281+
conv2d_output_name = output.name
282+
if output.dtype == ts.DType.INT8:
283+
conv2d_res = tosa_graph.addIntermediate(
284+
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
285+
)
286+
conv2d_output_name = conv2d_res.name
282287
acc_type = (
283288
inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32
284289
)
285290

291+
tosa_graph.addConst(
292+
[1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
293+
)
294+
tosa_graph.addConst(
295+
[1], output.dtype, [0], name=f"{conv2d_output_name}_weight_zp"
296+
)
297+
286298
# Non-bias case.
287299
if len(node.all_input_nodes) == 2:
288300
# Create a zero bias tensor if not presented
289301
out_channels = weight.shape[0]
290-
bias_name = "bias" + node.name.split("default", 1)[1]
302+
bias_name = f"{conv2d_output_name}_bias"
291303
bias_type = output.dtype
292304
if output.dtype == ts.DType.INT8:
293305
# Conv is quantized to int8, but the TOSA operator has
@@ -301,14 +313,6 @@ def define_node(
301313
name=bias_name,
302314
)
303315

304-
# The output type is int32 when input type is int8.
305-
conv2d_output_name = output.name
306-
if output.dtype == ts.DType.INT8:
307-
conv2d_res = tosa_graph.addIntermediate(
308-
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
309-
)
310-
conv2d_output_name = conv2d_res.name
311-
312316
# Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
313317
in_channels = input.shape[1]
314318
out_channels = weight.shape[0]
@@ -373,8 +377,8 @@ def define_node(
373377
input.name,
374378
weight_name,
375379
bias.name,
376-
f"{node.name}_input_zp",
377-
f"{node.name}_weight_zp",
380+
f"{conv2d_output_name}_input_zp",
381+
f"{conv2d_output_name}_weight_zp",
378382
],
379383
[conv2d_output_name],
380384
attr,

backends/arm/tosa_quant_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def create_const_ops_for_rescale(
239239
tosa_fb,
240240
scale_32,
241241
input_dtype,
242-
input_name,
242+
node_name,
243243
multipliers,
244244
shifts,
245245
input_zp,
@@ -252,16 +252,16 @@ def create_const_ops_for_rescale(
252252
(len(multipliers),),
253253
ts.DType.INT32 if scale_32 else ts.DType.INT16,
254254
multipliers,
255-
name=input_name + "_multipliers",
255+
name=node_name + "_multipliers",
256256
)
257257
shifts = tosa_fb.addConst(
258-
(len(shifts),), ts.DType.INT8, shifts, name=input_name + "_shifts"
258+
(len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts"
259259
)
260260
input_zp = tosa_fb.addConst(
261-
[1], input_dtype, [input_zp], name=input_name + "_input_zp"
261+
[1], input_dtype, [input_zp], name=node_name + "_input_zp"
262262
)
263263
output_zp = tosa_fb.addConst(
264-
[1], output_dtype, [output_zp], name=input_name + "_output_zp"
264+
[1], output_dtype, [output_zp], name=node_name + "_output_zp"
265265
)
266266

267267
return [multipliers.name, shifts.name, input_zp.name, output_zp.name]
@@ -281,16 +281,14 @@ def build_rescale(
281281
import serializer.tosa_serializer as ts # type: ignore
282282
import tosa.Op as TosaOp # type: ignore
283283

284-
input_name = input_node.name
285-
286284
scaleWidth = 32
287285
is_scale32 = True
288286
multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth)
289287
rescale_inputs = create_const_ops_for_rescale(
290288
tosa_fb,
291289
is_scale32,
292290
input_node.dtype,
293-
input_name,
291+
output_name,
294292
multipliers,
295293
shifts,
296294
input_zp,

0 commit comments

Comments
 (0)