Skip to content

Commit 9658ffd

Browse files
Sebastian-LarssonHaiting Pu
authored and
Haiting Pu
committed
Arm backend: Create op utility function for num input verification (#10713)
Create function for validating that the number of inputs is correct in each node visitor. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 9a7f9fa commit 9658ffd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+409
-85
lines changed

backends/arm/operators/op_abs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
NodeVisitor,
1414
register_node_visitor,
1515
)
16+
from executorch.backends.arm.operators.operator_validation_utils import (
17+
validate_num_inputs,
18+
)
1619
from executorch.backends.arm.tosa_mapping import TosaArg
1720
from executorch.backends.arm.tosa_specification import TosaSpecification
1821
from torch.fx import Node
@@ -39,6 +42,7 @@ def define_node(
3942

4043
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4144

45+
validate_num_inputs(self.target, inputs, 1)
4246
# Specification (0.80) states that input and output types
4347
# should all be the same
4448
if not (inputs[0].dtype == output.dtype):
@@ -105,6 +109,7 @@ def define_node(
105109

106110
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107111

112+
validate_num_inputs(self.target, inputs, 1)
108113
# Specification (0.80) states that input and output types
109114
# should all be the same
110115
if not (inputs[0].dtype == output.dtype):
@@ -157,6 +162,8 @@ def define_node(
157162

158163
import serializer.tosa_serializer as ts # type: ignore
159164

165+
validate_num_inputs(self.target, inputs, 1)
166+
160167
# Specification (1.0) states that input and output types
161168
# should all be the same
162169
if not (inputs[0].dtype == output.dtype):
@@ -224,6 +231,8 @@ def define_node(
224231

225232
import serializer.tosa_serializer as ts # type: ignore
226233

234+
validate_num_inputs(self.target, inputs, 1)
235+
227236
# Specification (1.0) states that input and output types
228237
# should all be the same
229238
if not (inputs[0].dtype == output.dtype):

backends/arm/operators/op_add.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
NodeVisitor,
1515
register_node_visitor,
1616
)
17+
from executorch.backends.arm.operators.operator_validation_utils import (
18+
validate_num_inputs,
19+
)
1720
from executorch.backends.arm.tosa_mapping import TosaArg
1821
from executorch.backends.arm.tosa_specification import TosaSpecification
1922
from torch.fx import Node
@@ -40,6 +43,7 @@ def define_node(
4043

4144
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4245

46+
validate_num_inputs(self.target, inputs, 2)
4347
# Specification (0.80) states that input and output types
4448
# should all be the same
4549
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -118,6 +122,7 @@ def define_node(
118122

119123
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120124

125+
validate_num_inputs(self.target, inputs, 2)
121126
# Specification (0.80) states that input and output types
122127
# should all be the same
123128
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -169,6 +174,8 @@ def define_node(
169174

170175
import serializer.tosa_serializer as ts # type: ignore
171176

177+
validate_num_inputs(self.target, inputs, 2)
178+
172179
# Specification (1.0) states that input and output types
173180
# should all be the same
174181
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -237,6 +244,8 @@ def define_node(
237244

238245
import serializer.tosa_serializer as ts # type: ignore
239246

247+
validate_num_inputs(self.target, inputs, 2)
248+
240249
# Specification (1.0) states that input and output types
241250
# should all be the same
242251
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:

backends/arm/operators/op_amax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
NodeVisitor,
1010
register_node_visitor,
1111
)
12+
from executorch.backends.arm.operators.operator_validation_utils import (
13+
validate_num_inputs,
14+
)
1215
from executorch.backends.arm.tosa_mapping import TosaArg
1316
from torch.fx import Node
1417

@@ -31,6 +34,8 @@ def define_node(
3134
) -> None:
3235
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3336

37+
validate_num_inputs(self.target, inputs, 3)
38+
3439
input = inputs[0]
3540
dim = inputs[1].number
3641

@@ -71,6 +76,8 @@ def define_node(
7176
) -> None:
7277
import serializer.tosa_serializer as ts
7378

79+
validate_num_inputs(self.target, inputs, 3)
80+
7481
input = inputs[0]
7582
dim = inputs[1].number
7683

backends/arm/operators/op_amin.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
NodeVisitor,
1010
register_node_visitor,
1111
)
12+
from executorch.backends.arm.operators.operator_validation_utils import (
13+
validate_num_inputs,
14+
)
1215
from executorch.backends.arm.tosa_mapping import TosaArg
1316
from torch.fx import Node
1417

@@ -31,6 +34,8 @@ def define_node(
3134
) -> None:
3235
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3336

37+
validate_num_inputs(self.target, inputs, 3)
38+
3439
input = inputs[0]
3540
dim = inputs[1].number
3641

@@ -71,6 +76,8 @@ def define_node(
7176
) -> None:
7277
import serializer.tosa_serializer as ts
7378

79+
validate_num_inputs(self.target, inputs, 3)
80+
7481
input = inputs[0]
7582
dim = inputs[1].number
7683

backends/arm/operators/op_any.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
NodeVisitor,
1111
register_node_visitor,
1212
)
13+
from executorch.backends.arm.operators.operator_validation_utils import (
14+
validate_num_inputs,
15+
)
1316

1417
from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
1518
from torch.fx import Node
@@ -30,6 +33,8 @@ def define_node(
3033
) -> None:
3134
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3235

36+
validate_num_inputs(self.target, inputs, 3)
37+
3338
if not (inputs[0].dtype == output.dtype):
3439
raise ValueError(
3540
"All inputs and outputs need same dtype."
@@ -69,6 +74,8 @@ def define_node(
6974
) -> None:
7075
import serializer.tosa_serializer as ts
7176

77+
validate_num_inputs(self.target, inputs, 3)
78+
7279
if not (inputs[0].dtype == output.dtype):
7380
raise ValueError(
7481
"All inputs and outputs need same dtype."

backends/arm/operators/op_avg_pool2d.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
NodeVisitor,
1717
register_node_visitor,
1818
)
19+
from executorch.backends.arm.operators.operator_validation_utils import (
20+
validate_num_inputs,
21+
)
1922
from executorch.backends.arm.tosa_mapping import TosaArg
2023
from executorch.backends.arm.tosa_specification import TosaSpecification
2124

@@ -85,6 +88,8 @@ def define_node(
8588
) -> None:
8689
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
8790

91+
validate_num_inputs(self.target, inputs, [3, 4, 6])
92+
8893
supported_dtypes = [ts.DType.INT8]
8994
if inputs[0].dtype not in supported_dtypes:
9095
raise TypeError(
@@ -122,6 +127,8 @@ def define_node(
122127
) -> None:
123128
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
124129

130+
validate_num_inputs(self.target, inputs, [3, 4, 6])
131+
125132
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
126133
if inputs[0].dtype not in supported_dtypes:
127134
raise TypeError(
@@ -212,6 +219,8 @@ def define_node(
212219
) -> None:
213220
import serializer.tosa_serializer as ts # type: ignore
214221

222+
validate_num_inputs(self.target, inputs, [3, 4, 6])
223+
215224
supported_dtypes = [ts.DType.INT8]
216225
if inputs[0].dtype not in supported_dtypes:
217226
raise TypeError(
@@ -252,6 +261,8 @@ def define_node(
252261
) -> None:
253262
import serializer.tosa_serializer as ts # type: ignore
254263

264+
validate_num_inputs(self.target, inputs, [3, 4, 6])
265+
255266
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
256267
if inputs[0].dtype not in supported_dtypes:
257268
raise TypeError(

backends/arm/operators/op_bmm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
NodeVisitor,
1818
register_node_visitor,
1919
)
20-
20+
from executorch.backends.arm.operators.operator_validation_utils import (
21+
validate_num_inputs,
22+
)
2123
from executorch.backends.arm.tosa_mapping import TosaArg
2224
from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80
2325
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -46,6 +48,7 @@ def define_node(
4648

4749
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4850

51+
validate_num_inputs(self.target, inputs, 2)
4952
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
5053
raise TypeError(
5154
f"All IO needs to have the same data type, got: "
@@ -128,6 +131,8 @@ def define_node(
128131

129132
import serializer.tosa_serializer as ts # type: ignore
130133

134+
validate_num_inputs(self.target, inputs, 2)
135+
131136
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
132137
raise TypeError(
133138
f"All IO needs to have the same data type, got: "

backends/arm/operators/op_cat.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
NodeVisitor,
1212
register_node_visitor,
1313
)
14+
from executorch.backends.arm.operators.operator_validation_utils import (
15+
validate_num_inputs,
16+
)
1417
from executorch.backends.arm.tosa_mapping import TosaArg
1518
from torch.fx import Node
1619

@@ -33,6 +36,8 @@ def define_node(
3336
) -> None:
3437
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3538

39+
validate_num_inputs(self.target, inputs, [1, 2])
40+
3641
tensors = inputs[0].special
3742
dim = 0 if len(inputs) < 2 else inputs[1].number
3843
rank = len(output.shape)
@@ -68,6 +73,8 @@ def define_node(
6873
) -> None:
6974
import serializer.tosa_serializer as ts
7075

76+
validate_num_inputs(self.target, inputs, [1, 2])
77+
7178
tensors = inputs[0].special
7279
dim = 0 if len(inputs) < 2 else inputs[1].number
7380
rank = len(output.shape)

backends/arm/operators/op_clamp.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
NodeVisitor,
1616
register_node_visitor,
1717
)
18+
from executorch.backends.arm.operators.operator_validation_utils import (
19+
validate_num_inputs,
20+
)
1821

1922
from executorch.backends.arm.tosa_mapping import TosaArg
2023
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -65,9 +68,6 @@ def cast_type(value: Any) -> int | float:
6568
# Attempt to cast to float
6669
return float(value)
6770

68-
if len(node.args) != 2 and len(node.args) != 3:
69-
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")
70-
7171
min_arg = dtype_min
7272
max_arg = dtype_max
7373

@@ -87,10 +87,7 @@ def define_node(
8787
inputs: List[TosaArg],
8888
output: TosaArg,
8989
) -> None:
90-
if len(node.all_input_nodes) != 1:
91-
raise ValueError(
92-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
93-
)
90+
validate_num_inputs(self.target, inputs, [2, 3])
9491

9592
min_int8, max_int8 = self._get_min_max_arguments(
9693
node,
@@ -130,10 +127,7 @@ def define_node(
130127
) -> None:
131128
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
132129

133-
if len(node.all_input_nodes) != 1:
134-
raise ValueError(
135-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
136-
)
130+
validate_num_inputs(self.target, inputs, [2, 3])
137131

138132
if inputs[0].dtype == ts.DType.INT8:
139133
# Call the inherited define_node for handling integers
@@ -178,9 +172,6 @@ def cast_type(value: Any) -> int | float:
178172
# Attempt to cast to float
179173
return float(value)
180174

181-
if len(node.args) != 2 and len(node.args) != 3:
182-
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")
183-
184175
min_arg = dtype_min
185176
max_arg = dtype_max
186177

@@ -202,10 +193,7 @@ def define_node(
202193
) -> None:
203194
import serializer.tosa_serializer as ts # type: ignore
204195

205-
if len(node.all_input_nodes) != 1:
206-
raise ValueError(
207-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
208-
)
196+
validate_num_inputs(self.target, inputs, [2, 3])
209197

210198
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
211199
min_int8, max_int8 = self._get_min_max_arguments(
@@ -247,10 +235,7 @@ def define_node(
247235
) -> None:
248236
import serializer.tosa_serializer as ts # type: ignore
249237

250-
if len(node.all_input_nodes) != 1:
251-
raise ValueError(
252-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
253-
)
238+
validate_num_inputs(self.target, inputs, [2, 3])
254239

255240
min_fp32, max_fp32 = self._get_min_max_arguments(
256241
node,

backends/arm/operators/op_constant_pad_nd.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
NodeVisitor,
1717
register_node_visitor,
1818
)
19+
from executorch.backends.arm.operators.operator_validation_utils import (
20+
validate_num_inputs,
21+
)
1922
from executorch.backends.arm.tosa_mapping import TosaArg
2023
from executorch.backends.arm.tosa_specification import TosaSpecification
2124

@@ -39,6 +42,8 @@ def define_node(
3942
) -> None:
4043
import tosa_tools.v0_80.serializer.tosa_serializer as ts
4144

45+
validate_num_inputs(self.target, inputs, 3)
46+
4247
if inputs[0].dtype == ts.DType.INT8:
4348
input_qparams = get_input_qparams(node)
4449
qargs = input_qparams[0]
@@ -98,9 +103,10 @@ def define_node(
98103
inputs: List[TosaArg],
99104
output: TosaArg,
100105
) -> None:
101-
102106
import serializer.tosa_serializer as ts # type: ignore
103107

108+
validate_num_inputs(self.target, inputs, 3)
109+
104110
if inputs[0].dtype == ts.DType.INT8:
105111
input_qparams = get_input_qparams(node)
106112
qargs = input_qparams[0]

0 commit comments

Comments
 (0)