4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
# pyre-unsafe
7
- from typing import Any , List
7
+ from typing import List
8
8
9
9
import torch
10
10
11
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
12
+
11
13
from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
12
14
get_input_qparams ,
13
15
get_output_qparams ,
@@ -34,16 +36,14 @@ def __init__(self, *args):
34
36
def _build_generic_avgpool2d (
35
37
self ,
36
38
node : torch .fx .Node ,
37
- tosa_graph : Any ,
39
+ tosa_graph : ts . TosaSerializer ,
38
40
inputs : List [TosaArg ],
39
41
output : TosaArg ,
40
42
input_zp : int ,
41
43
output_zp : int ,
42
- accumulator_type : Any ,
44
+ accumulator_type : ts . DType ,
43
45
) -> None :
44
46
45
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
46
-
47
47
input_tensor = inputs [0 ]
48
48
kernel_size_list = inputs [1 ].special
49
49
stride_size_list = inputs [2 ].special
@@ -79,12 +79,10 @@ def _build_generic_avgpool2d(
79
79
def define_node (
80
80
self ,
81
81
node : torch .fx .Node ,
82
- tosa_graph : Any ,
82
+ tosa_graph : ts . TosaSerializer ,
83
83
inputs : List [TosaArg ],
84
84
output : TosaArg ,
85
85
) -> None :
86
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
87
-
88
86
input_tensor = inputs [0 ]
89
87
assert input_tensor .dtype == ts .DType .INT8
90
88
@@ -112,135 +110,10 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
112
110
def define_node (
113
111
self ,
114
112
node : torch .fx .Node ,
115
- tosa_graph : Any ,
113
+ tosa_graph : ts . TosaSerializer ,
116
114
inputs : List [TosaArg ],
117
115
output : TosaArg ,
118
116
) -> None :
119
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120
-
121
- assert (
122
- inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
123
- ), "Only FP32 and INT8 supported"
124
-
125
- if inputs [0 ].dtype == ts .DType .INT8 :
126
- super ().define_node (node , tosa_graph , inputs , output )
127
-
128
- if inputs [0 ].dtype == ts .DType .FP32 :
129
- accumulator_type = ts .DType .FP32
130
- # Initilize zero point to zero.
131
- input_zp = 0
132
- output_zp = 0
133
-
134
- self ._build_generic_avgpool2d (
135
- node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
136
- )
137
-
138
-
139
- @register_node_visitor
140
- class AvgPool2dVisitor (NodeVisitor ):
141
- target = "aten.avg_pool2d.default"
142
-
143
- tosa_specs = [
144
- TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
145
- ]
146
-
147
- def __init__ (self , * args ):
148
- super ().__init__ (* args )
149
-
150
- def _build_generic_avgpool2d (
151
- self ,
152
- node : torch .fx .Node ,
153
- tosa_graph : Any ,
154
- inputs : List [TosaArg ],
155
- output : TosaArg ,
156
- input_zp : int ,
157
- output_zp : int ,
158
- accumulator_type : Any ,
159
- ) -> None :
160
-
161
- import serializer .tosa_serializer as ts # type: ignore
162
-
163
- input_tensor = inputs [0 ]
164
- kernel_size_list = inputs [1 ].special
165
- stride_size_list = inputs [2 ].special
166
-
167
- try :
168
- pad_size_list = inputs [3 ].special
169
- pad_size_list = [
170
- pad_size_list [0 ],
171
- pad_size_list [0 ],
172
- pad_size_list [1 ],
173
- pad_size_list [1 ],
174
- ]
175
- except IndexError :
176
- pad_size_list = [0 , 0 , 0 , 0 ]
177
-
178
- attr = ts .TosaSerializerAttribute ()
179
- attr .AvgPool2dAttribute (
180
- kernel = kernel_size_list ,
181
- stride = stride_size_list ,
182
- pad = pad_size_list ,
183
- acc_type = accumulator_type ,
184
- )
185
- input_zp_tensor = tosa_graph .addConst (
186
- shape = [1 ], dtype = output .dtype , vals = [input_zp ]
187
- )
188
- output_zp_tensor = tosa_graph .addConst (
189
- shape = [1 ], dtype = output .dtype , vals = [output_zp ]
190
- )
191
-
192
- tosa_graph .addOperator (
193
- ts .TosaOp .Op ().AVG_POOL2D ,
194
- [input_tensor .name , input_zp_tensor .name , output_zp_tensor .name ],
195
- [output .name ],
196
- attr ,
197
- )
198
-
199
- def define_node (
200
- self ,
201
- node : torch .fx .Node ,
202
- tosa_graph : Any ,
203
- inputs : List [TosaArg ],
204
- output : TosaArg ,
205
- ) -> None :
206
- import serializer .tosa_serializer as ts # type: ignore
207
-
208
- input_tensor = inputs [0 ]
209
- assert input_tensor .dtype == ts .DType .INT8
210
-
211
- accumulator_type = ts .DType .INT32
212
-
213
- input_qargs = get_input_qparams (node )
214
- input_zp = input_qargs [0 ].zp
215
-
216
- output_qargs = get_output_qparams (node )
217
- output_zp = output_qargs [0 ].zp
218
-
219
- self ._build_generic_avgpool2d (
220
- node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
221
- )
222
-
223
-
224
- @register_node_visitor
225
- class AvgPool2dVisitor_FP (AvgPool2dVisitor ):
226
- target = "aten.avg_pool2d.default"
227
-
228
- tosa_specs = [
229
- TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
230
- ]
231
-
232
- def __init__ (self , * args ):
233
- super ().__init__ (* args )
234
-
235
- def define_node (
236
- self ,
237
- node : torch .fx .Node ,
238
- tosa_graph : Any ,
239
- inputs : List [TosaArg ],
240
- output : TosaArg ,
241
- ) -> None :
242
- import serializer .tosa_serializer as ts # type: ignore
243
-
244
117
assert (
245
118
inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
246
119
), "Only FP32 and INT8 supported"
0 commit comments