Skip to content

Arm Backend: Add New DecomposeSilu pass to arm_pass_manager #9448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .decompose_linear_pass import DecomposeLinearPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_silu_pass import DecomposeSiluPass # noqa
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
DecomposeLinearPass,
DecomposeMeanDimPass,
DecomposeSelectPass,
DecomposeSiluPass,
DecomposeSoftmaxPass,
DecomposeSoftmaxUnstablePass,
DecomposeSqrtPass,
Expand Down Expand Up @@ -196,6 +197,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeSqrtPass())
self.add_pass(DecomposeSiluPass())

if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Expand Down
34 changes: 34 additions & 0 deletions backends/arm/_passes/decompose_silu_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.pass_base import ExportPass

aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default)


class DecomposeSiluPass(ExportPass):
    """Rewrite aten silu (and its in-place variant) as sigmoid + mul.

    Example:
        y = silu(a)
    becomes:
        x = sigmoid(a)
        y = mul(a, x)
    """

    def call_operator(self, op, args, kwargs, meta):
        # Pass every non-silu operator through unchanged.
        if op not in aten_silu_ops:
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        # silu(x) == x * sigmoid(x): emit the sigmoid node first, then gate
        # the original input with it via an elementwise multiply.
        gate = super().call_operator(torch.ops.aten.sigmoid.default, (x,), {}, meta)
        return super().call_operator(torch.ops.aten.mul.Tensor, (x, gate), {}, meta)
4 changes: 2 additions & 2 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,10 @@ def _annotate_all_static_patterns(
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> GraphModule:
"""Loops over all STATIC_OPS and runs the corresponding registred annotator.
"""Loops over all STATIC_OPS and runs the corresponding registered annotator.
Args:
model: The model to annotate statically.
quantization_config: Specifices the QuantizationSpecs for the model's
quantization_config: Specifies the QuantizationSpecs for the model's
input activations, output activations, weights and biases.
filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
Returns:
Expand Down
113 changes: 113 additions & 0 deletions backends/arm/test/ops/test_silu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no tests for SDPA?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No tests were added because this is not essentially a new function — it is just a new call, and the idea is that it is tested elsewhere. We have tested it functionally, just not with a unit test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding test_sdpa.py would ensure we can lower and run it. Where else it is tested? Couldn't find it in the test/ops dir.

Copy link
Collaborator Author

@ArmRyan ArmRyan Mar 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is tested from the QNN delegate but I can add unit tests here also! I will push that when I can ( probably Monday )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay, had some issues and went on holidays, came back to more issues so I have removed SDPA for now

# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)


input_t = Tuple[torch.Tensor]


class Silu(torch.nn.Module):
    """Wrapper module around torch.nn.SiLU used by the silu lowering tests.

    The optional ``_inplace`` flag selects between ``aten.silu`` (False) and
    the in-place ``aten.silu_`` variant (True).
    """

    def forward(
        self,
        _input: torch.Tensor,
        _inplace: Optional[bool] = False,
    ):
        return torch.nn.SiLU(inplace=_inplace)(_input)

    # Named test inputs, keyed by test-case id. Fix: this is a dict, so
    # annotate it as dict[str, input_t] rather than list[input_t].
    test_data: dict[str, input_t] = {
        "op_silu_rank1_ones": (torch.ones(5),),
        "op_silu_rank1_negative_ones": (torch.ones(5) * (-1),),
        "op_silu_rank1_rand": (torch.rand(5) * 5,),
        "op_silu_rank4_ones": (torch.ones(1, 10, 25, 20),),
        "op_silu_rank4_negative_ones": ((-1) * torch.ones(1, 10, 25, 20),),
        "op_silu_rank4_large_rand": (200 * torch.rand(1, 10, 25, 20),),
        "op_silu_rank4_negative_large_rand": ((-200) * torch.rand(1, 10, 25, 20),),
        "op_silu_rank4_large_randn": (200 * torch.randn(1, 10, 25, 20) + 1,),
    }

    # Expected aten ops: silu before decomposition (MI), sigmoid+mul after (BI).
    aten_op_MI = "torch.ops.aten.silu.default"
    aten_op_inplace_MI = "torch.ops.aten.silu_.default"
    aten_op_BI = ["torch.ops.aten.sigmoid.default", "torch.ops.aten.mul.Tensor"]


@common.parametrize("test_data", Silu.test_data)
def test_silu_tosa_MI(test_data: input_t):
    # Lower the non-inplace silu through the TOSA MI pipeline.
    TosaPipelineMI[input_t](Silu(), (test_data[0], False), Silu.aten_op_MI).run()


@common.parametrize("test_data", Silu.test_data)
def test_silu_tosa_MI_inplace(test_data: input_t):
    # Lower the in-place silu_ variant through the TOSA MI pipeline.
    TosaPipelineMI[input_t](
        Silu(), (test_data[0], True), Silu.aten_op_inplace_MI
    ).run()


@common.parametrize("test_data", Silu.test_data)
def test_silu_tosa_BI(test_data: input_t):
    # Quantized (BI) lowering: silu is decomposed, so expect sigmoid + mul.
    TosaPipelineBI[input_t](Silu(), (test_data[0], False), Silu.aten_op_BI).run()


@common.parametrize("test_data", Silu.test_data)
def test_silu_tosa_BI_inplace(test_data: input_t):
    # Quantized (BI) lowering of the in-place variant; same decomposed ops.
    TosaPipelineBI[input_t](Silu(), (test_data[0], True), Silu.aten_op_BI).run()


@common.parametrize("test_data", Silu.test_data)
@common.XfailIfNoCorstone300
def test_silu_u55_BI(test_data: input_t):
    # Non-inplace silu on the Ethos-U55 BI pipeline, executed on the FVP.
    pipeline = EthosU55PipelineBI[input_t](
        Silu(), (test_data[0], False), Silu.aten_op_BI, run_on_fvp=True
    )
    pipeline.run()


@common.parametrize("test_data", Silu.test_data)
@common.XfailIfNoCorstone300
def test_silu_u55_BI_inplace(test_data: input_t):
    # In-place silu_ on the Ethos-U55 BI pipeline, executed on the FVP.
    pipeline = EthosU55PipelineBI[input_t](
        Silu(), (test_data[0], True), Silu.aten_op_BI, run_on_fvp=True
    )
    pipeline.run()


@common.parametrize("test_data", Silu.test_data)
@common.XfailIfNoCorstone320
def test_silu_u85_BI(test_data: input_t):
    # Non-inplace silu on the Ethos-U85 BI pipeline, executed on the FVP.
    pipeline = EthosU85PipelineBI[input_t](
        Silu(), (test_data[0], False), Silu.aten_op_BI, run_on_fvp=True
    )
    pipeline.run()


@common.parametrize("test_data", Silu.test_data)
@common.XfailIfNoCorstone320
def test_silu_u85_BI_inplace(test_data: input_t):
    # In-place silu_ on the Ethos-U85 BI pipeline, executed on the FVP.
    pipeline = EthosU85PipelineBI[input_t](
        Silu(), (test_data[0], True), Silu.aten_op_BI, run_on_fvp=True
    )
    pipeline.run()
Loading