Skip to content

Commit fc01661

Browse files
authored
Arm Backend: Add New DecomposeSilu pass to arm_pass_manager (#9448)
* Adds DecomposeSilu pass * Adds Tests for DecomposeSilu Signed-off-by: Ryan O'Shea <[email protected]>
1 parent 0ffab50 commit fc01661

File tree

5 files changed

+152
-2
lines changed

5 files changed

+152
-2
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .decompose_linear_pass import DecomposeLinearPass # noqa
2727
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
2828
from .decompose_select import DecomposeSelectPass # noqa
29+
from .decompose_silu_pass import DecomposeSiluPass # noqa
2930
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
3031
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
3132
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DecomposeLinearPass,
3232
DecomposeMeanDimPass,
3333
DecomposeSelectPass,
34+
DecomposeSiluPass,
3435
DecomposeSoftmaxPass,
3536
DecomposeSoftmaxUnstablePass,
3637
DecomposeSqrtPass,
@@ -196,6 +197,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
196197
self.add_pass(DecomposeDivPass())
197198
self.add_pass(DecomposeLeakyReLUPass())
198199
self.add_pass(DecomposeSqrtPass())
200+
self.add_pass(DecomposeSiluPass())
199201

200202
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
201203
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass
10+
11+
aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default)
12+
13+
14+
class DecomposeSiluPass(ExportPass):
15+
"""
16+
This pass decomposes silu into a mul and a sigmoid node.
17+
18+
Example:
19+
y = silu(a)
20+
Becomes:
21+
x = sigmoid(a)
22+
y = mul(a,x)
23+
"""
24+
25+
def call_operator(self, op, args, kwargs, meta):
26+
if op not in (aten_silu_ops):
27+
return super().call_operator(op, args, kwargs, meta)
28+
sigmoid_op = torch.ops.aten.sigmoid.default
29+
mul_op = torch.ops.aten.mul.Tensor
30+
31+
original = args[0]
32+
sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta)
33+
34+
return super().call_operator(mul_op, (original, sigmoid), {}, meta)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,10 @@ def _annotate_all_static_patterns(
286286
quantization_config: Optional[QuantizationConfig],
287287
filter_fn: Optional[Callable[[Node], bool]] = None,
288288
) -> GraphModule:
289-
"""Loops over all STATIC_OPS and runs the corresponding registred annotator.
289+
"""Loops over all STATIC_OPS and runs the corresponding registered annotator.
290290
Args:
291291
model: The model to annotate statically.
292-
quantization_config: Specifices the QuantizationSpecs for the model's
292+
quantization_config: Specifies the QuantizationSpecs for the model's
293293
input activations, output activations, weights and biases.
294294
filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
295295
Returns:

backends/arm/test/ops/test_silu.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
from typing import Optional, Tuple
10+
11+
import torch
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
EthosU55PipelineBI,
15+
EthosU85PipelineBI,
16+
TosaPipelineBI,
17+
TosaPipelineMI,
18+
)
19+
20+
21+
input_t = Tuple[torch.Tensor]
22+
23+
24+
class Silu(torch.nn.Module):
25+
def forward(
26+
self,
27+
_input: torch.Tensor,
28+
_inplace: Optional[bool] = False,
29+
):
30+
return torch.nn.SiLU(inplace=_inplace)(_input)
31+
32+
test_data: list[input_t] = {
33+
"op_silu_rank1_ones": (torch.ones(5),),
34+
"op_silu_rank1_negative_ones": (torch.ones(5) * (-1),),
35+
"op_silu_rank1_rand": (torch.rand(5) * 5,),
36+
"op_silu_rank4_ones": (torch.ones(1, 10, 25, 20),),
37+
"op_silu_rank4_negative_ones": ((-1) * torch.ones(1, 10, 25, 20),),
38+
"op_silu_rank4_large_rand": (200 * torch.rand(1, 10, 25, 20),),
39+
"op_silu_rank4_negative_large_rand": ((-200) * torch.rand(1, 10, 25, 20),),
40+
"op_silu_rank4_large_randn": (200 * torch.randn(1, 10, 25, 20) + 1,),
41+
}
42+
43+
aten_op_MI = "torch.ops.aten.silu.default"
44+
aten_op_inplace_MI = "torch.ops.aten.silu_.default"
45+
aten_op_BI = ["torch.ops.aten.sigmoid.default", "torch.ops.aten.mul.Tensor"]
46+
47+
48+
@common.parametrize("test_data", Silu.test_data)
49+
def test_silu_tosa_MI(test_data: input_t):
50+
silu_data = (test_data[0], False)
51+
pipeline = TosaPipelineMI[input_t](Silu(), silu_data, Silu.aten_op_MI)
52+
pipeline.run()
53+
54+
55+
@common.parametrize("test_data", Silu.test_data)
56+
def test_silu_tosa_MI_inplace(test_data: input_t):
57+
silu_data = (test_data[0], True)
58+
pipeline = TosaPipelineMI[input_t](Silu(), silu_data, Silu.aten_op_inplace_MI)
59+
pipeline.run()
60+
61+
62+
@common.parametrize("test_data", Silu.test_data)
63+
def test_silu_tosa_BI(test_data: input_t):
64+
silu_data = (test_data[0], False)
65+
pipeline = TosaPipelineBI[input_t](Silu(), silu_data, Silu.aten_op_BI)
66+
pipeline.run()
67+
68+
69+
@common.parametrize("test_data", Silu.test_data)
70+
def test_silu_tosa_BI_inplace(test_data: input_t):
71+
silu_data = (test_data[0], True)
72+
pipeline = TosaPipelineBI[input_t](Silu(), silu_data, Silu.aten_op_BI)
73+
pipeline.run()
74+
75+
76+
@common.parametrize("test_data", Silu.test_data)
77+
@common.XfailIfNoCorstone300
78+
def test_silu_u55_BI(test_data: input_t):
79+
silu_data = (test_data[0], False)
80+
pipeline = EthosU55PipelineBI[input_t](
81+
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True
82+
)
83+
pipeline.run()
84+
85+
86+
@common.parametrize("test_data", Silu.test_data)
87+
@common.XfailIfNoCorstone300
88+
def test_silu_u55_BI_inplace(test_data: input_t):
89+
silu_data = (test_data[0], True)
90+
pipeline = EthosU55PipelineBI[input_t](
91+
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True
92+
)
93+
pipeline.run()
94+
95+
96+
@common.parametrize("test_data", Silu.test_data)
97+
@common.XfailIfNoCorstone320
98+
def test_silu_u85_BI(test_data: input_t):
99+
silu_data = (test_data[0], False)
100+
pipeline = EthosU85PipelineBI[input_t](
101+
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True
102+
)
103+
pipeline.run()
104+
105+
106+
@common.parametrize("test_data", Silu.test_data)
107+
@common.XfailIfNoCorstone320
108+
def test_silu_u85_BI_inplace(test_data: input_t):
109+
silu_data = (test_data[0], True)
110+
pipeline = EthosU85PipelineBI[input_t](
111+
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True
112+
)
113+
pipeline.run()

0 commit comments

Comments
 (0)