Commit e7f33c9

Merge branch 'main' into fix_logging
2 parents 91f9d91 + bd98b5e commit e7f33c9

32 files changed (+1269 / -751 lines)

.github/release.yml

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# .github/release.yml

changelog:
  exclude:
    labels:
      - ignore-for-release
  categories:
    - title: Breaking Changes 🛠
      labels:
        - Semver-Major
        - breaking-change
    - title: API
      labels:
        - "release notes: api"
    - title: ARM
      labels:
        - "release notes: arm"
    - title: NXP
      labels:
        - "release notes: nxp"
    - title: Exir
      labels:
        - "release notes: exir"
    - title: Misc
      labels:
        - "release notes: misc"
    - title: Apple
      labels:
        - "release notes: apple"
    - title: Build
      labels:
        - "release notes: build"
    - title: Vulkan
      labels:
        - "release notes: vulkan"
    - title: Cadence
      labels:
        - "release notes: cadence"
    - title: Runtime
      labels:
        - "release notes: runtime"
    - title: XNNPACK
      labels:
        - "release notes: xnnpack"
    - title: Devtools
      labels:
        - "release notes: devtools"
    - title: Examples
      labels:
        - "release notes: examples"
    - title: Mediatek
      labels:
        - "release notes: mediatek"
    - title: Openvino
      labels:
        - "release notes: openvino"
    - title: Qualcomm
      labels:
        - "release notes: qualcomm"
    - title: Training
      labels:
        - "release notes: training"
    - title: Quantization
      labels:
        - "release notes: quantization"
    - title: Ops & kernels
      labels:
        - "release notes: ops & kernels"
    - title: Other Changes
      labels:
        - "*"

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
+from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa
 from .decompose_meandim_pass import DecomposeMeanDimPass  # noqa
 from .decompose_select import DecomposeSelectPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@
     DecomposeBatchNormPass,
     DecomposeDivPass,
     DecomposeLayerNormPass,
+    DecomposeLeakyReLUPass,
     DecomposeLinearPass,
     DecomposeMeanDimPass,
     DecomposeSelectPass,
@@ -121,6 +122,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeBatchNormPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
@@ -178,6 +180,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(DecomposeLeakyReLUPass())

         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             # Numerically stable softmax uses amax which is not supported on Ethos-U55

backends/arm/_passes/decompose_leaky_relu_pass.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops

edge_ops = (exir_ops.edge.aten.leaky_relu.default,)
torch_ops = (torch.ops.aten.leaky_relu.default,)


def _get_leaky_relu_ops(op) -> tuple:
    if op in edge_ops:
        return (
            exir_ops.edge.aten.clamp.default,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.add.Tensor,
        )
    elif op in torch_ops:
        return (
            torch.ops.aten.clamp.default,
            torch.ops.aten.full.default,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.add.Tensor,
        )
    else:
        raise RuntimeError(f"Can't get decomposition ops for op {op}")


class DecomposeLeakyReLUPass(ArmPass):
    """
    This pass decomposes Leaky ReLU into primitive operations.
    LeakyReLU(x,slope) = max(0,x) + slope * min(0,x)

    Example:
        %op1 = clamp(x,0,None) (equivalent to max(0,x))
        %op2 = clamp(x,None,0) (equivalent to min(0,x))
        %op3 = full(x.shape,slope)
        %op4 = mul(%op3,%op2)
        %op5 = add(%op1,%op4)
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (edge_ops + torch_ops):
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        slope = args[1] if len(args) > 1 else 0.01
        dtype = x.node.meta["val"].dtype
        clamp, full, mul, add = _get_leaky_relu_ops(op)
        op1 = super().call_operator(
            op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta
        )
        op2 = super().call_operator(
            op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta
        )
        op3 = super().call_operator(
            op=full,
            args=(x.node.meta["val"].shape, slope),
            kwargs={"dtype": dtype},
            meta=meta,
        )
        op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta)
        op5 = super().call_operator(op=add, args=(op1, op4), kwargs=kwargs, meta=meta)
        return op5
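
The identity in the docstring is easy to sanity-check outside the pass with plain PyTorch ops. A minimal standalone sketch (arbitrary test values) that mirrors the clamp/full/mul/add sequence the pass emits:

import torch

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
slope = 0.01

op1 = torch.clamp(x, min=0)                      # max(0, x)
op2 = torch.clamp(x, max=0)                      # min(0, x)
op3 = torch.full(x.shape, slope, dtype=x.dtype)  # slope broadcast to x's shape
decomposed = op1 + op3 * op2                     # mul then add, as in the pass

assert torch.allclose(decomposed, torch.nn.functional.leaky_relu(x, slope))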

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 8 additions & 2 deletions
@@ -136,8 +136,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 continue

             # Make sure we haven't already set qparams meta information on the node
-            assert "input_qparams" not in n.meta.keys()
-            assert "output_qparams" not in n.meta.keys()
+            assert "input_qparams" not in n.meta, (
+                f'Unexpected key "input_qparams" found in meta for node {n}. '
+                "input_qparams should not have been set at this point"
+            )
+            assert "output_qparams" not in n.meta, (
+                f'Unexpected key "output_qparams" found in meta for node {n}. '
+                "output_qparams should not have been set at this point"
+            )

             # for the inputs and outputs search the graph for quantization info and
             # store the information in a dict with order of the _tensor_ inputs as key,

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ class TableOps:
     # Targets that follow a straigtforward one-to-one mapping to their table op
     unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
         exir_ops.edge.aten.ceil.default: torch.ceil,
+        exir_ops.edge.aten.erf.default: torch.erf,
         exir_ops.edge.aten.exp.default: torch.exp,
         exir_ops.edge.aten.floor.default: torch.floor,
         exir_ops.edge.aten.log.default: torch.log,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 3 additions & 0 deletions
@@ -166,6 +166,7 @@ def is_node_supported(
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.eq.Tensor,
             exir_ops.edge.aten.eq.Scalar,
+            exir_ops.edge.aten.erf.default,
             exir_ops.edge.aten.exp.default,
             exir_ops.edge.aten.log.default,
             exir_ops.edge.aten.linear.default,
@@ -192,6 +193,7 @@
             exir_ops.edge.aten.repeat.default,
             exir_ops.edge.aten.reciprocal.default,
             exir_ops.edge.aten.relu.default,
+            exir_ops.edge.aten.leaky_relu.default,
             exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.select_copy.int,
@@ -257,6 +259,7 @@
             exir_ops.edge.aten.sub.Scalar,
             exir_ops.edge.aten.mul.Scalar,
             exir_ops.edge.aten.div.Scalar,
+            exir_ops.edge.aten.leaky_relu.default,
         ]
         if needs_decomp:
             self.reporter.report_reject(node, "Needs to be decomposed.")

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     op_constant_pad_nd,
     op_conv2d,
     op_eq,
+    op_erf,
     op_exp,
     op_full,
     op_ge,

backends/arm/operators/op_add.py

Lines changed: 21 additions & 5 deletions
@@ -41,9 +41,18 @@ def define_node(
     ) -> None:
         # Specification (0.80) states that input and output types
         # should all be the same
-        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+        if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
+            raise TypeError(
+                f"All IO needs to have the same data type, got input 1: "
+                f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
+                f"{output.dtype}"
+            )
         # Handle int8 (quantized) and int32
-        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
+        supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
+        if inputs[0].dtype not in supported_dtypes:
+            raise TypeError(
+                f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
+            )

         dim_order = (
             inputs[0].dim_order
@@ -105,15 +114,22 @@ def define_node(
     ) -> None:
         # Specification (0.80) states that input and output types
         # should all be the same
-        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+        if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
+            raise TypeError(
+                f"All IO needs to have the same data type, got input 1: "
+                f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
+                f"{output.dtype}"
+            )

         if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output)
         else:
             # FP32 Add lowering
-            assert inputs[0].dtype == ts.DType.FP32
-            assert output.dtype == ts.DType.FP32
+            if inputs[0].dtype != ts.DType.FP32:
+                raise TypeError(
+                    f"Expected IO data type to be FP32, got {inputs[0].dtype}"
+                )

             input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
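
The same assert-to-exception pattern is applied to op_sigmoid.py and op_tanh.py below. Besides the richer error messages, one practical consideration is that assert statements are stripped when Python runs with the -O flag, so they cannot be relied on to validate user-provided graphs, while an explicit raise always fires. A tiny standalone illustration (the dtype strings are placeholders, not real ts.DType members):

dtype_a, dtype_b = "INT8", "FP32"  # placeholders standing in for ts.DType values

# Under `python -O`, `assert dtype_a == dtype_b` compiles away to nothing,
# so the mismatch would go unnoticed. An explicit check always runs:
try:
    if dtype_a != dtype_b:
        raise TypeError(
            f"All IO needs to have the same data type, got {dtype_a} and {dtype_b}"
        )
except TypeError as e:
    print(e)  # All IO needs to have the same data type, got INT8 and FP32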

backends/arm/operators/op_erf.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts  # type: ignore
import torch.fx
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class ERFVisitor_080_MI(NodeVisitor):
    target = "aten.erf.default"

    # BI case handled by op_table
    tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        if not (inputs[0].dtype == output.dtype):
            raise ValueError(
                "All inputs and output need same dtype."
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )
        if not (inputs[0].dtype == ts.DType.FP32):
            raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
        # MI lowering
        tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name])
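
For context, erf is the Gauss error function, erf(x) = 2/sqrt(pi) * ∫₀ˣ exp(-t²) dt. Per the comment in the visitor, the floating-point MI profile lowers it directly to the TOSA ERF operator, while the quantized BI case goes through the table lowering added above. A quick check that torch.erf matches the math module's reference implementation:

import math
import torch

x = torch.tensor([0.0, 0.5, 1.0])
print(torch.erf(x))                                  # tensor([0.0000, 0.5205, 0.8427])
print([round(math.erf(v), 4) for v in x.tolist()])   # [0.0, 0.5205, 0.8427]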

backends/arm/operators/op_sigmoid.py

Lines changed: 9 additions & 2 deletions
@@ -36,7 +36,14 @@ def define_node(
         output: TosaArg,
     ) -> None:

-        assert len(node.all_input_nodes) == 1
-        assert inputs[0].dtype == output.dtype == ts.DType.FP32
+        if len(node.all_input_nodes) != 1:
+            raise ValueError(
+                f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
+            )
+        if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
+            raise ValueError(
+                f"Input and output for {self.target} need to be FP32, got input_dtype: "
+                f"{inputs[0].dtype} and output_dtype: {output.dtype}"
+            )

         tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])

backends/arm/operators/op_tanh.py

Lines changed: 10 additions & 1 deletion
@@ -34,5 +34,14 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
-        assert inputs[0].dtype == output.dtype == ts.DType.FP32
+        if len(node.all_input_nodes) != 1:
+            raise ValueError(
+                f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
+            )
+        if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
+            raise ValueError(
+                f"Input and output for {self.target} need to be FP32, got input_dtype: "
+                f"{inputs[0].dtype} and output_dtype: {output.dtype}"
+            )
+
         tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name])

backends/arm/quantizer/quantization_annotator.py

Lines changed: 3 additions & 2 deletions
@@ -164,6 +164,7 @@ def _match_pattern(
 _one_to_one = [
     torch.ops.aten.abs.default,
     torch.ops.aten.ceil.default,
+    torch.ops.aten.erf.default,
     torch.ops.aten.exp.default,
     torch.ops.aten.floor.default,
     torch.ops.aten.log.default,
@@ -217,6 +218,8 @@ def _match_pattern(
     torch.ops.aten.pad.default,
     torch.ops.aten.amax.default,
     torch.ops.aten.amin.default,
+    torch.ops.aten.clamp.default,
+    torch.ops.aten.clamp.Tensor,
 ]

 # Operators that can inherit the quantization specs from its parent node
@@ -236,8 +239,6 @@ def _match_pattern(
     torch.ops.aten.flatten.using_ints,
     torch.ops.aten.dropout.default,
     torch.ops.aten.dropout_.default,
-    torch.ops.aten.clamp.default,
-    torch.ops.aten.clamp.Tensor,
     torch.ops.aten.where,
     operator.getitem,
 ]
