Qualcomm AI Engine Direct - alias_copy op #10319


Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/remove_redundancy.py
@@ -22,6 +22,7 @@ def __init__(self):
exir_ops.edge.aten.clone.default: self._default_condition,
torch.ops.aten.alias.default: self._default_condition,
exir_ops.edge.aten.alias.default: self._default_condition,
exir_ops.edge.aten.alias_copy.default: self._default_condition,
exir_ops.edge.aten.lift_fresh_copy.default: self._default_condition,
# remove this target if '_skip_dim_order' is set to False
exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
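For context, a pass like this typically rewires each matched no-op to its input and then erases the node. Below is a minimal sketch assuming a plain torch.fx GraphModule; the op set mirrors the table above, but the traversal (and the omission of the per-op condition hooks) is illustrative rather than the exact upstream implementation.

# Minimal sketch of a redundancy-removal pass: no-op nodes such as
# alias/alias_copy/clone are dropped by rewiring their users to the
# node's input. Illustrative only, not the exact ExecuTorch code.
import torch
from torch.fx import GraphModule


def remove_redundancy(graph_module: GraphModule) -> GraphModule:
    redundant_ops = {
        torch.ops.aten.alias.default,
        torch.ops.aten.clone.default,
    }
    # Iterate over a snapshot since we mutate the graph while walking it.
    for node in list(graph_module.graph.nodes):
        if node.op == "call_function" and node.target in redundant_ops:
            # The no-op's tensor input takes over all of its uses,
            # after which the node itself can be erased.
            node.replace_all_uses_with(node.args[0])
            graph_module.graph.erase_node(node)
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()
    return graph_module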
11 changes: 11 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -10,6 +10,17 @@
# module with related operator only


# Ensure alias_copy is removed by the remove_redundancy pass
class Alias(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x):
alias_x = torch.ops.aten.alias.default(x)
return self.relu(alias_x)


class And(torch.nn.Module):
def __init__(self, pos, neg):
super().__init__()
11 changes: 11 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -118,6 +118,11 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
sample_input = (torch.randn(1, 512, 7, 7),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_alias(self):
module = Alias() # noqa: F405
sample_input = (torch.randn(1, 10),)
self.lower_module_and_test_output(module, sample_input)
Contributor:

Are we testing that the alias op should still work, i.e. that it's not removed?

Collaborator (Author):

Hi @cccclai,
Thanks for reviewing the PR.
The alias op will be removed. This test just shows that the alias op is properly removed and the model still works correctly. I have added a comment to the model.

Contributor:

Hmm, I assume this test will pass with or without the alias op, correct? I was thinking of exporting the graph, running to_backend, and checking that there is no alias op after that.

Collaborator (Author):

> Hmm, I assume this test will pass with or without the alias op, correct? I was thinking of exporting the graph, running to_backend, and checking that there is no alias op after that.

This test is to reproduce @billmguo's error. Without adding exir_ops.edge.aten.alias_copy.default to the RemoveRedundancy pass, it fails during qnn_partitioner, which has no op_builder for aten.alias_copy.default.

Thanks for the suggestion. That would probably be more straightforward, since the unit test does not actually run the alias op once it is dropped during RemoveRedundancy. Maybe we can add a new class of unit tests targeting whether passes work as expected.
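For reference, a pass-level check along the lines of that suggestion could look like the sketch below. The import paths and the pass-call convention (an ExportPass invocation returning a PassResult) are assumptions based on this PR's file layout, not a verified implementation:

import torch
from executorch.backends.qualcomm._passes.remove_redundancy import RemoveRedundancy
from executorch.backends.qualcomm.tests.models import Alias
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops


def test_remove_redundancy_drops_alias():
    module = Alias()
    sample_input = (torch.randn(1, 10),)

    # Export to the edge dialect, then run the pass directly.
    edge = to_edge(torch.export.export(module, sample_input))
    graph_module = RemoveRedundancy()(
        edge.exported_program().graph_module
    ).graph_module

    alias_targets = {
        exir_ops.edge.aten.alias.default,
        exir_ops.edge.aten.alias_copy.default,
    }
    # After the pass, no alias-style node should survive in the graph.
    assert not any(
        node.op == "call_function" and node.target in alias_targets
        for node in graph_module.graph.nodes
    )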


def test_qnn_backend_amax(self):
modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(4, 4),)
@@ -1156,6 +1161,12 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_alias(self):
module = Alias() # noqa: F405
sample_input = (torch.randn(1, 10),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_amax(self):
modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(4, 4),)