@@ -277,17 +277,29 @@ def define_node(
277
277
input_qparams = get_input_qparams (node )
278
278
input_zp = input_qparams [0 ].zp
279
279
280
- tosa_graph .addConst ([1 ], output .dtype , [input_zp ], name = f"{ node .name } _input_zp" )
281
- tosa_graph .addConst ([1 ], output .dtype , [0 ], name = f"{ node .name } _weight_zp" )
280
+ # The output type is int32 when input type is int8.
281
+ conv2d_output_name = output .name
282
+ if output .dtype == ts .DType .INT8 :
283
+ conv2d_res = tosa_graph .addIntermediate (
284
+ tosa_shape (output .shape , output .dim_order ), ts .DType .INT32
285
+ )
286
+ conv2d_output_name = conv2d_res .name
282
287
acc_type = (
283
288
inputs [0 ].dtype if inputs [0 ].dtype == ts .DType .FP32 else ts .DType .INT32
284
289
)
285
290
291
+ tosa_graph .addConst (
292
+ [1 ], output .dtype , [input_zp ], name = f"{ conv2d_output_name } _input_zp"
293
+ )
294
+ tosa_graph .addConst (
295
+ [1 ], output .dtype , [0 ], name = f"{ conv2d_output_name } _weight_zp"
296
+ )
297
+
286
298
# Non-bias case.
287
299
if len (node .all_input_nodes ) == 2 :
288
300
# Create a zero bias tensor if not presented
289
301
out_channels = weight .shape [0 ]
290
- bias_name = "bias" + node . name . split ( "default" , 1 )[ 1 ]
302
+ bias_name = f" { conv2d_output_name } _bias"
291
303
bias_type = output .dtype
292
304
if output .dtype == ts .DType .INT8 :
293
305
# Conv is quantized to int8, but the TOSA operator has
@@ -301,14 +313,6 @@ def define_node(
301
313
name = bias_name ,
302
314
)
303
315
304
- # The output type is int32 when input type is int8.
305
- conv2d_output_name = output .name
306
- if output .dtype == ts .DType .INT8 :
307
- conv2d_res = tosa_graph .addIntermediate (
308
- tosa_shape (output .shape , output .dim_order ), ts .DType .INT32
309
- )
310
- conv2d_output_name = conv2d_res .name
311
-
312
316
# Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
313
317
in_channels = input .shape [1 ]
314
318
out_channels = weight .shape [0 ]
@@ -373,8 +377,8 @@ def define_node(
373
377
input .name ,
374
378
weight_name ,
375
379
bias .name ,
376
- f"{ node . name } _input_zp" ,
377
- f"{ node . name } _weight_zp" ,
380
+ f"{ conv2d_output_name } _input_zp" ,
381
+ f"{ conv2d_output_name } _weight_zp" ,
378
382
],
379
383
[conv2d_output_name ],
380
384
attr ,
0 commit comments