Skip to content

Commit 23c11c7

Browse files
committed
Arm Backend: Add New Silu and SDPA Decomp passes to arm_pass_manager
* Adds DecomposeSilu pass
* Adds DecomposeScaledDotProductAttention pass
* Adds tests for DecomposeSilu

Signed-off-by: Ryan O'Shea <[email protected]>
Change-Id: Ib9f15d04c4c06d92d38cc9e6297145980052e673
1 parent 059e4b0 commit 23c11c7

File tree

4 files changed

+191
-2
lines changed

4 files changed

+191
-2
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
4545
DecomposeSelectPass,
4646
)
47+
from executorch.backends.arm._passes.decompose_silu_pass import DecomposeSiluPass
4748
from executorch.backends.arm._passes.decompose_softmax_pass import DecomposeSoftmaxPass
4849
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (
4950
DecomposeSoftmaxUnstablePass,
@@ -83,6 +84,9 @@
8384
UnsqueezeScalarPlaceholdersPass,
8485
)
8586
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
87+
from executorch.backends.transforms.decompose_sdpa import (
88+
DecomposeScaledDotProductAttention,
89+
)
8690
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
8791

8892
from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -205,6 +209,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
205209
self.add_pass(DecomposeVarPass())
206210
self.add_pass(DecomposeMeanDimPass())
207211
self.add_pass(DecomposeDivPass())
212+
self.add_pass(DecomposeSiluPass())
213+
self.add_pass(DecomposeScaledDotProductAttention())
208214

209215
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
210216
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass
10+
11+
aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default)


class DecomposeSiluPass(ExportPass):
    """
    Rewrite every silu node as a sigmoid node followed by a mul node.

    Before:
        y = silu(a)
    After:
        x = sigmoid(a)
        y = mul(a, x)
    """

    def call_operator(self, op, args, kwargs, meta):
        # Only silu / silu_ are rewritten; every other op is forwarded as-is.
        if op in aten_silu_ops:
            silu_input = args[0]
            # Both replacement nodes reuse the meta of the original silu node.
            gate = super().call_operator(
                torch.ops.aten.sigmoid.default, (silu_input,), {}, meta
            )
            return super().call_operator(
                torch.ops.aten.mul.Tensor, (silu_input, gate), {}, meta
            )
        return super().call_operator(op, args, kwargs, meta)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,10 @@ def _annotate_all_static_patterns(
284284
quantization_config: Optional[QuantizationConfig],
285285
filter_fn: Optional[Callable[[Node], bool]] = None,
286286
) -> GraphModule:
287-
"""Loops over all STATIC_OPS and runs the corresponding registred annotator.
287+
"""Loops over all STATIC_OPS and runs the corresponding registered annotator.
288288
Args:
289289
model: The model to annotate statically.
290-
quantization_config: Specifices the QuantizationSpecs for the model's
290+
quantization_config: Specifies the QuantizationSpecs for the model's
291291
input activations, output activations, weights and biases.
292292
filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
293293
Returns:

backends/arm/test/ops/test_silu.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
from typing import Optional, Tuple
10+
11+
import torch
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
EthosU55PipelineBI,
15+
EthosU85PipelineBI,
16+
TosaPipelineBI,
17+
TosaPipelineMI,
18+
)
19+
20+
21+
input_t = Tuple[torch.Tensor]

# Operator names handed to the test pipelines below.
aten_op_MI = "torch.ops.aten.silu.default"
aten_op_inplace_MI = "torch.ops.aten.silu_.default"
aten_op_BI = ["torch.ops.aten.sigmoid.default", "torch.ops.aten.mul.Tensor"]


class Silu(torch.nn.Module):
    """Thin module wrapper around ``torch.nn.SiLU`` used by the tests in this file."""

    def forward(
        self,
        _input: torch.Tensor,
        _inplace: Optional[bool] = False,
    ):
        """Apply SiLU to ``_input``; ``_inplace`` is forwarded to ``torch.nn.SiLU``."""
        return torch.nn.SiLU(inplace=_inplace)(_input)

    # Maps a test-case name to its single-tensor input tuple. This is a dict
    # (it is consumed by ``common.parametrize``), so it is annotated as a
    # mapping rather than as a list.
    test_data: dict[str, input_t] = {
        "op_silu_rank1_ones": (torch.ones(5),),
        "op_silu_rank1_negative_ones": (torch.ones(5) * (-1),),
        "op_silu_rank1_rand": (torch.rand(5) * 5,),
        "op_silu_rank4_ones": (torch.ones(1, 10, 25, 20),),
        "op_silu_rank4_negative_ones": ((-1) * torch.ones(1, 10, 25, 20),),
        "op_silu_rank4_large_rand": (200 * torch.rand(1, 10, 25, 20),),
        "op_silu_rank4_negative_large_rand": ((-200) * torch.rand(1, 10, 25, 20),),
        "op_silu_rank4_large_randn": (200 * torch.randn(1, 10, 25, 20) + 1,),
    }
46+
47+
48+
@common.parametrize("test_data", Silu.test_data)
def test_silu_tosa_MI(test_data: input_t):
    """Run out-of-place SiLU through TosaPipelineMI."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], False)
    TosaPipelineMI[input_t](Silu(), model_inputs, aten_op_MI).run()
53+
54+
55+
@common.parametrize("test_data", Silu.test_data)
def test_silu_inplace_tosa_MI(test_data: input_t):
    """Run in-place SiLU through TosaPipelineMI."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], True)
    TosaPipelineMI[input_t](Silu(), model_inputs, aten_op_inplace_MI).run()
60+
61+
62+
@common.parametrize("test_data", Silu.test_data)
def test_silu_tosa_BI(test_data: input_t):
    """Run out-of-place SiLU through TosaPipelineBI."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], False)
    TosaPipelineBI[input_t](Silu(), model_inputs, aten_op_BI).run()
67+
68+
69+
@common.parametrize("test_data", Silu.test_data)
def test_silu_inplace_tosa_BI(test_data: input_t):
    """Run in-place SiLU through TosaPipelineBI."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], True)
    TosaPipelineBI[input_t](Silu(), model_inputs, aten_op_BI).run()
74+
75+
76+
@common.parametrize("test_data", Silu.test_data)
def test_silu_u55_BI(test_data: input_t):
    """Run out-of-place SiLU through EthosU55PipelineBI with run_on_fvp=False."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], False)
    EthosU55PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=False
    ).run()
83+
84+
85+
@common.parametrize("test_data", Silu.test_data)
def test_silu_inplace_u55_BI(test_data: input_t):
    """Run in-place SiLU through EthosU55PipelineBI with run_on_fvp=False."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], True)
    EthosU55PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=False
    ).run()
92+
93+
94+
@common.parametrize("test_data", Silu.test_data)
def test_silu_u85_BI(test_data: input_t):
    """Run out-of-place SiLU through EthosU85PipelineBI with run_on_fvp=False."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], False)
    EthosU85PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=False
    ).run()
101+
102+
103+
@common.parametrize("test_data", Silu.test_data)
def test_silu_inplace_u85_BI(test_data: input_t):
    """Run in-place SiLU through EthosU85PipelineBI with run_on_fvp=False."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], True)
    EthosU85PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=False
    ).run()
110+
111+
112+
@common.parametrize("test_data", Silu.test_data)
@common.SkipIfNoCorstone300
def test_silu_u55_BI_on_fvp(test_data: input_t):
    """Run out-of-place SiLU through EthosU55PipelineBI with run_on_fvp=True."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], False)
    EthosU55PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=True
    ).run()
120+
121+
122+
@common.parametrize("test_data", Silu.test_data)
@common.SkipIfNoCorstone300
def test_silu_inplace_u55_BI_on_fvp(test_data: input_t):
    """Run in-place SiLU through EthosU55PipelineBI with run_on_fvp=True."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], True)
    EthosU55PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=True
    ).run()
130+
131+
132+
@common.parametrize("test_data", Silu.test_data)
@common.SkipIfNoCorstone320
def test_silu_u85_BI_on_fvp(test_data: input_t):
    """Run out-of-place SiLU through EthosU85PipelineBI with run_on_fvp=True."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], False)
    EthosU85PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=True
    ).run()
140+
141+
142+
@common.parametrize("test_data", Silu.test_data)
@common.SkipIfNoCorstone320
def test_silu_inplace_u85_BI_on_fvp(test_data: input_t):
    """Run in-place SiLU through EthosU85PipelineBI with run_on_fvp=True."""
    # The second model input is the in-place flag consumed by Silu.forward.
    model_inputs = (test_data[0], True)
    EthosU85PipelineBI[input_t](
        Silu(), model_inputs, aten_op_BI, run_on_fvp=True
    ).run()

0 commit comments

Comments
 (0)