Skip to content

Commit cc73c7e

Browse files
authored
Revert "Arm backend: Update more node visitors to support TOSA 1.0" (#10455)
Reverts #10425
1 parent 1432243 commit cc73c7e

32 files changed

+236
-1519
lines changed

backends/arm/_passes/insert_table_ops.py

-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ class TableOps:
4848
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
4949
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
5050
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
51-
exir_ops.edge.aten.cos.default: torch.cos,
52-
exir_ops.edge.aten.sin.default: torch.sin,
5351
exir_ops.edge.aten.tanh.default: torch.tanh,
5452
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5553
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,

backends/arm/operator_support/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
pool_2d_support,
1313
reduce_sum_support,
1414
right_shift_support,
15-
sin_cos_support,
1615
slice_copy_support,
1716
to_copy_support,
1817
tosa_supported_operators,

backends/arm/operator_support/sin_cos_support.py

-32
This file was deleted.

backends/arm/operator_support/tosa_supported_operators.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,7 @@
2323
EthosU55NotSupported,
2424
EthosU55TransposeCheck,
2525
)
26-
from executorch.backends.arm.tosa_specification import (
27-
Tosa_0_80,
28-
Tosa_1_00,
29-
TosaSpecification,
30-
)
26+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
3127
from executorch.exir import ExportedProgram
3228
from executorch.exir.backend.utils import WhyNoPartitionReporter
3329
from executorch.exir.dialects._ops import ops as exir_ops
@@ -128,9 +124,7 @@ def tosa_support_factory(
128124
if not tosa_spec.support_float():
129125
negative_checks.append(NeedsDecompositionCheck(reporter))
130126
negative_checks.append(CheckProperQuantization(reporter))
131-
if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
132-
isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
133-
):
127+
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
134128
negative_checks.append(EthosU55NotSupported(reporter))
135129
negative_checks.append(EthosU55DtypeSupport(reporter))
136130
negative_checks.append(EthosU55TransposeCheck(reporter))

backends/arm/operators/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
op_clamp,
1919
op_constant_pad_nd,
2020
op_conv2d,
21-
op_cos,
2221
op_eq,
2322
op_erf,
2423
op_exp,
@@ -39,7 +38,6 @@
3938
op_rshift_tensor,
4039
op_rsqrt,
4140
op_sigmoid,
42-
op_sin,
4341
op_slice,
4442
op_sub,
4543
op_sum,

backends/arm/operators/op_any.py

+4-45
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, cast, List
7+
from typing import cast, List
88

9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
910
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
1011
NodeVisitor,
1112
register_node_visitor,
@@ -15,59 +16,17 @@
1516
from torch.fx import Node
1617

1718

18-
@register_node_visitor
19-
class AnyVisitor_0_80(NodeVisitor):
20-
target = "aten.any.dim"
21-
22-
tosa_specs = NodeVisitor.tosa_specs_0_80
23-
24-
def define_node(
25-
self,
26-
node: Node,
27-
tosa_graph: Any,
28-
inputs: List[TosaArg],
29-
output: TosaArg,
30-
) -> None:
31-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
32-
33-
if not (inputs[0].dtype == output.dtype):
34-
raise ValueError(
35-
"All inputs and outputs need same dtype."
36-
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
37-
)
38-
if not (inputs[0].dtype == ts.DType.BOOL):
39-
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
40-
41-
input_shape = list(inputs[0].shape)
42-
dim = cast(int, inputs[1].number) % len(
43-
input_shape
44-
) # process the negative index
45-
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
46-
if not keep_dim:
47-
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
48-
49-
attr = ts.TosaSerializerAttribute()
50-
attr.AxisAttribute(inputs[0].dim_order.index(dim))
51-
52-
tosa_graph.addOperator(
53-
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
54-
)
55-
56-
5719
@register_node_visitor
5820
class AnyVisitor(NodeVisitor):
5921
target = "aten.any.dim"
6022

61-
tosa_specs = NodeVisitor.tosa_specs_1_00
62-
6323
def define_node(
6424
self,
6525
node: Node,
66-
tosa_graph: Any,
26+
tosa_graph: ts.TosaSerializer,
6727
inputs: List[TosaArg],
6828
output: TosaArg,
6929
) -> None:
70-
import serializer.tosa_serializer as ts
7130

7231
if not (inputs[0].dtype == output.dtype):
7332
raise ValueError(
@@ -86,7 +45,7 @@ def define_node(
8645
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
8746

8847
attr = ts.TosaSerializerAttribute()
89-
attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))
48+
attr.AxisAttribute(inputs[0].dim_order.index(dim))
9049

9150
tosa_graph.addOperator(
9251
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr

backends/arm/operators/op_avg_pool2d.py

+7-134
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

99
import torch
1010

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
1113
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1214
get_input_qparams,
1315
get_output_qparams,
@@ -34,16 +36,14 @@ def __init__(self, *args):
3436
def _build_generic_avgpool2d(
3537
self,
3638
node: torch.fx.Node,
37-
tosa_graph: Any,
39+
tosa_graph: ts.TosaSerializer,
3840
inputs: List[TosaArg],
3941
output: TosaArg,
4042
input_zp: int,
4143
output_zp: int,
42-
accumulator_type: Any,
44+
accumulator_type: ts.DType,
4345
) -> None:
4446

45-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
46-
4747
input_tensor = inputs[0]
4848
kernel_size_list = inputs[1].special
4949
stride_size_list = inputs[2].special
@@ -79,12 +79,10 @@ def _build_generic_avgpool2d(
7979
def define_node(
8080
self,
8181
node: torch.fx.Node,
82-
tosa_graph: Any,
82+
tosa_graph: ts.TosaSerializer,
8383
inputs: List[TosaArg],
8484
output: TosaArg,
8585
) -> None:
86-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
87-
8886
input_tensor = inputs[0]
8987
assert input_tensor.dtype == ts.DType.INT8
9088

@@ -112,135 +110,10 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
112110
def define_node(
113111
self,
114112
node: torch.fx.Node,
115-
tosa_graph: Any,
113+
tosa_graph: ts.TosaSerializer,
116114
inputs: List[TosaArg],
117115
output: TosaArg,
118116
) -> None:
119-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120-
121-
assert (
122-
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
123-
), "Only FP32 and INT8 supported"
124-
125-
if inputs[0].dtype == ts.DType.INT8:
126-
super().define_node(node, tosa_graph, inputs, output)
127-
128-
if inputs[0].dtype == ts.DType.FP32:
129-
accumulator_type = ts.DType.FP32
130-
# Initilize zero point to zero.
131-
input_zp = 0
132-
output_zp = 0
133-
134-
self._build_generic_avgpool2d(
135-
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
136-
)
137-
138-
139-
@register_node_visitor
140-
class AvgPool2dVisitor(NodeVisitor):
141-
target = "aten.avg_pool2d.default"
142-
143-
tosa_specs = [
144-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145-
]
146-
147-
def __init__(self, *args):
148-
super().__init__(*args)
149-
150-
def _build_generic_avgpool2d(
151-
self,
152-
node: torch.fx.Node,
153-
tosa_graph: Any,
154-
inputs: List[TosaArg],
155-
output: TosaArg,
156-
input_zp: int,
157-
output_zp: int,
158-
accumulator_type: Any,
159-
) -> None:
160-
161-
import serializer.tosa_serializer as ts # type: ignore
162-
163-
input_tensor = inputs[0]
164-
kernel_size_list = inputs[1].special
165-
stride_size_list = inputs[2].special
166-
167-
try:
168-
pad_size_list = inputs[3].special
169-
pad_size_list = [
170-
pad_size_list[0],
171-
pad_size_list[0],
172-
pad_size_list[1],
173-
pad_size_list[1],
174-
]
175-
except IndexError:
176-
pad_size_list = [0, 0, 0, 0]
177-
178-
attr = ts.TosaSerializerAttribute()
179-
attr.AvgPool2dAttribute(
180-
kernel=kernel_size_list,
181-
stride=stride_size_list,
182-
pad=pad_size_list,
183-
acc_type=accumulator_type,
184-
)
185-
input_zp_tensor = tosa_graph.addConst(
186-
shape=[1], dtype=output.dtype, vals=[input_zp]
187-
)
188-
output_zp_tensor = tosa_graph.addConst(
189-
shape=[1], dtype=output.dtype, vals=[output_zp]
190-
)
191-
192-
tosa_graph.addOperator(
193-
ts.TosaOp.Op().AVG_POOL2D,
194-
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
195-
[output.name],
196-
attr,
197-
)
198-
199-
def define_node(
200-
self,
201-
node: torch.fx.Node,
202-
tosa_graph: Any,
203-
inputs: List[TosaArg],
204-
output: TosaArg,
205-
) -> None:
206-
import serializer.tosa_serializer as ts # type: ignore
207-
208-
input_tensor = inputs[0]
209-
assert input_tensor.dtype == ts.DType.INT8
210-
211-
accumulator_type = ts.DType.INT32
212-
213-
input_qargs = get_input_qparams(node)
214-
input_zp = input_qargs[0].zp
215-
216-
output_qargs = get_output_qparams(node)
217-
output_zp = output_qargs[0].zp
218-
219-
self._build_generic_avgpool2d(
220-
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
221-
)
222-
223-
224-
@register_node_visitor
225-
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
226-
target = "aten.avg_pool2d.default"
227-
228-
tosa_specs = [
229-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
230-
]
231-
232-
def __init__(self, *args):
233-
super().__init__(*args)
234-
235-
def define_node(
236-
self,
237-
node: torch.fx.Node,
238-
tosa_graph: Any,
239-
inputs: List[TosaArg],
240-
output: TosaArg,
241-
) -> None:
242-
import serializer.tosa_serializer as ts # type: ignore
243-
244117
assert (
245118
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
246119
), "Only FP32 and INT8 supported"

0 commit comments

Comments
 (0)