-
Notifications
You must be signed in to change notification settings - Fork 587
Arm backend: Update fuse_batchnorm_pass to create new placeholders #8411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
54acf1b
[ARM backend] Update fuse_batchnorm_pass to create new placeholders
AdrianLundell 71e791e
Move create/delete_constant_node utils to shared folder
AdrianLundell dc59086
Merge branch 'main' of https://github.com/pytorch/executorch into cha…
AdrianLundell 18a1f7e
Merge branch 'main' of https://github.com/pytorch/executorch into cha…
AdrianLundell 67bc6cd
Add buck dependency
AdrianLundell cbcd78c
Merge branch 'main' into change-987432
digantdesai 6f71967
Fix bazel build
AdrianLundell 23b6a58
Merge branch 'main' of https://github.com/pytorch/executorch into cha…
AdrianLundell 1f9a61c
Merge branch 'change-987432' of github.com:AdrianLundell/executorch i…
AdrianLundell 73c995b
Merge branch 'main' into change-987432
AdrianLundell 8a45bfa
Merge branch 'main' into change-987432
digantdesai d091301
Merge branch 'main' into change-987432
digantdesai 1e629f7
Merge branch 'main' into change-987432
AdrianLundell a5e51e7
Merge branch 'main' into change-987432
AdrianLundell 9113109
Merge branch 'main' into change-987432
digantdesai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
123 changes: 123 additions & 0 deletions
123
backends/transforms/test/test_create_delete_constant_placeholder.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.backends.transforms.utils import ( | ||
create_constant_placeholder, | ||
delete_constant_placeholder, | ||
) | ||
from executorch.exir import to_edge | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from torch.export import export | ||
from torch.export.graph_signature import InputKind | ||
|
||
|
||
class EmptyNetwork(torch.nn.Module): | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
return x | ||
|
||
test_data: torch.Tensor = (torch.zeros(1),) | ||
|
||
|
||
def _test_create_delete(kind: InputKind, persistent_buffer: bool = None): | ||
""" | ||
Tests the utility functions create_constant_placeholder and delete_constant_placeholder | ||
""" | ||
|
||
# Toy network with two nodes, input and output | ||
# The result should be 0 = 0 | ||
module = EmptyNetwork() | ||
exported_program = export(module, args=module.test_data) | ||
exported_program = to_edge(exported_program).exported_program() | ||
graph = exported_program.graph_module.graph | ||
assert len(graph.nodes) == 2 | ||
assert exported_program.module()(torch.zeros(1)) == 0 | ||
assert len(exported_program.graph_signature.input_specs) == 1 | ||
assert len(exported_program.state_dict) == 0 | ||
assert len(exported_program.constants) == 0 | ||
|
||
const_name = "test_node" | ||
|
||
# Create one const node with value 1 and add it to the input | ||
input_node = list(graph.nodes)[0] | ||
with graph.inserting_before(input_node): | ||
const_node = create_constant_placeholder( | ||
exp_program=exported_program, | ||
graph=graph, | ||
kind=kind, | ||
name=const_name, | ||
data=torch.ones(1), | ||
persistent_buffer=persistent_buffer, | ||
) | ||
assert "val" in const_node.meta | ||
|
||
with graph.inserting_after(input_node): | ||
add_node = graph.create_node( | ||
"call_function", | ||
exir_ops.edge.aten.add.Tensor, | ||
args=(input_node, const_node), | ||
kwargs={}, | ||
) | ||
|
||
output_node = list(graph.nodes)[-1] | ||
output_node.replace_input_with(input_node, add_node) | ||
|
||
# We should now have four nodes: test_node, input, add, output | ||
# The result should be 0 + 1 = 1 | ||
assert exported_program.module()(torch.zeros(1)) == 1 | ||
assert len(graph.nodes) == 4 | ||
|
||
if kind == InputKind.PARAMETER: | ||
assert const_name in exported_program.graph_signature.inputs_to_parameters | ||
assert const_name in exported_program.state_dict | ||
assert len(exported_program.constants) == 0 | ||
elif kind == InputKind.BUFFER and persistent_buffer: | ||
assert const_name in exported_program.graph_signature.inputs_to_buffers | ||
assert const_name in exported_program.state_dict | ||
assert len(exported_program.constants) == 0 | ||
elif kind == InputKind.BUFFER and not persistent_buffer: | ||
assert const_name in exported_program.graph_signature.inputs_to_buffers | ||
assert len(exported_program.state_dict) == 0 | ||
assert const_name in exported_program.constants | ||
elif kind == InputKind.CONSTANT_TENSOR: | ||
assert ( | ||
const_name | ||
in exported_program.graph_signature.inputs_to_lifted_tensor_constants | ||
) | ||
assert len(exported_program.state_dict) == 0 | ||
assert const_name in exported_program.constants | ||
else: | ||
raise RuntimeError("Wrong input kind") | ||
|
||
# Replacing the add op and using eliminate_dead_code() deletes the add op but not the input op | ||
output_node.replace_input_with(add_node, input_node) | ||
graph.eliminate_dead_code() | ||
assert len(graph.nodes) == 3 | ||
|
||
# Delete the input op manually | ||
# The result should again be 0 = 0 | ||
delete_constant_placeholder(exported_program, const_node) | ||
assert exported_program.module()(torch.zeros(1)) == 0 | ||
assert len(graph.nodes) == 2 | ||
assert len(exported_program.graph_signature.input_specs) == 1 | ||
assert len(exported_program.state_dict) == 0 | ||
assert len(exported_program.constants) == 0 | ||
|
||
|
||
def test_create_delete_parameter(): | ||
_test_create_delete(InputKind.PARAMETER) | ||
|
||
|
||
def test_create_delete_persistent_buffer(): | ||
_test_create_delete(InputKind.BUFFER, True) | ||
|
||
|
||
def test_create_delete_non_persistent_buffer(): | ||
_test_create_delete(InputKind.BUFFER, False) | ||
|
||
|
||
def test_create_delete_constant_tensor(): | ||
_test_create_delete(InputKind.CONSTANT_TENSOR) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.