-
Notifications
You must be signed in to change notification settings - Fork 576
Arm Backend: Add New DecomposeSilu pass to arm_pass_manager #9448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
zingo
merged 2 commits into
pytorch:main
from
ArmRyan:experimental/arm_silu_sdpa_passes
Apr 15, 2025
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
|
||
import torch | ||
from executorch.exir.pass_base import ExportPass | ||
|
||
aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default) | ||
|
||
|
||
class DecomposeSiluPass(ExportPass): | ||
""" | ||
This pass decomposes silu into a mul and a sigmoid node. | ||
|
||
Example: | ||
y = silu(a) | ||
Becomes: | ||
x = sigmoid(a) | ||
y = mul(a,x) | ||
""" | ||
|
||
def call_operator(self, op, args, kwargs, meta): | ||
if op not in (aten_silu_ops): | ||
return super().call_operator(op, args, kwargs, meta) | ||
sigmoid_op = torch.ops.aten.sigmoid.default | ||
mul_op = torch.ops.aten.mul.Tensor | ||
|
||
original = args[0] | ||
sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta) | ||
|
||
return super().call_operator(mul_op, (original, sigmoid), {}, meta) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from typing import Optional, Tuple | ||
|
||
import torch | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.test_pipeline import ( | ||
EthosU55PipelineBI, | ||
EthosU85PipelineBI, | ||
TosaPipelineBI, | ||
TosaPipelineMI, | ||
) | ||
|
||
|
||
input_t = Tuple[torch.Tensor] | ||
|
||
|
||
class Silu(torch.nn.Module): | ||
def forward( | ||
self, | ||
_input: torch.Tensor, | ||
_inplace: Optional[bool] = False, | ||
): | ||
return torch.nn.SiLU(inplace=_inplace)(_input) | ||
|
||
test_data: list[input_t] = { | ||
"op_silu_rank1_ones": (torch.ones(5),), | ||
"op_silu_rank1_negative_ones": (torch.ones(5) * (-1),), | ||
"op_silu_rank1_rand": (torch.rand(5) * 5,), | ||
"op_silu_rank4_ones": (torch.ones(1, 10, 25, 20),), | ||
"op_silu_rank4_negative_ones": ((-1) * torch.ones(1, 10, 25, 20),), | ||
"op_silu_rank4_large_rand": (200 * torch.rand(1, 10, 25, 20),), | ||
"op_silu_rank4_negative_large_rand": ((-200) * torch.rand(1, 10, 25, 20),), | ||
"op_silu_rank4_large_randn": (200 * torch.randn(1, 10, 25, 20) + 1,), | ||
} | ||
|
||
aten_op_MI = "torch.ops.aten.silu.default" | ||
aten_op_inplace_MI = "torch.ops.aten.silu_.default" | ||
aten_op_BI = ["torch.ops.aten.sigmoid.default", "torch.ops.aten.mul.Tensor"] | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
def test_silu_tosa_MI(test_data: input_t): | ||
silu_data = (test_data[0], False) | ||
pipeline = TosaPipelineMI[input_t](Silu(), silu_data, Silu.aten_op_MI) | ||
pipeline.run() | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
def test_silu_tosa_MI_inplace(test_data: input_t): | ||
silu_data = (test_data[0], True) | ||
pipeline = TosaPipelineMI[input_t](Silu(), silu_data, Silu.aten_op_inplace_MI) | ||
pipeline.run() | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
def test_silu_tosa_BI(test_data: input_t): | ||
silu_data = (test_data[0], False) | ||
pipeline = TosaPipelineBI[input_t](Silu(), silu_data, Silu.aten_op_BI) | ||
pipeline.run() | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
def test_silu_tosa_BI_inplace(test_data: input_t): | ||
silu_data = (test_data[0], True) | ||
pipeline = TosaPipelineBI[input_t](Silu(), silu_data, Silu.aten_op_BI) | ||
pipeline.run() | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
@common.XfailIfNoCorstone300 | ||
def test_silu_u55_BI(test_data: input_t): | ||
silu_data = (test_data[0], False) | ||
pipeline = EthosU55PipelineBI[input_t]( | ||
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True | ||
) | ||
pipeline.run() | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
@common.XfailIfNoCorstone300 | ||
def test_silu_u55_BI_inplace(test_data: input_t): | ||
silu_data = (test_data[0], True) | ||
pipeline = EthosU55PipelineBI[input_t]( | ||
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True | ||
) | ||
pipeline.run() | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
@common.XfailIfNoCorstone320 | ||
def test_silu_u85_BI(test_data: input_t): | ||
silu_data = (test_data[0], False) | ||
pipeline = EthosU85PipelineBI[input_t]( | ||
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True | ||
) | ||
pipeline.run() | ||
|
||
|
||
@common.parametrize("test_data", Silu.test_data) | ||
@common.XfailIfNoCorstone320 | ||
def test_silu_u85_BI_inplace(test_data: input_t): | ||
silu_data = (test_data[0], True) | ||
pipeline = EthosU85PipelineBI[input_t]( | ||
Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True | ||
) | ||
pipeline.run() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no tests for SDPA?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There were no tests added because it was not a new function essentially, just a new call with the idea is that it is tested elsewhere. We have tested it functionally just not with a unit test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adding test_sdpa.py would ensure we can lower and run it. Where else it is tested? Couldn't find it in the test/ops dir.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is tested from the QNN delegate but I can add unit tests here also! I will push that when I can ( probably Monday )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay, had some issues and went on holidays, came back to more issues so I have removed SDPA for now