Skip to content

Commit 072403b

Browse files
Arm backend: Add FuseEqualPlaceholdersPass (#9893)
Internal tests looking good. Thanks.
1 parent 6371758 commit 072403b

File tree

5 files changed

+184
-1
lines changed

5 files changed

+184
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
4141
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
42+
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
4243
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
4344
from .insert_rescales_pass import InsertRescalePass # noqa
4445
from .insert_table_ops import InsertTableOpsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
FoldAndAnnotateQParamsPass,
4141
FuseBatchnorm2DPass,
4242
FuseConstantArgsPass,
43+
FuseEqualPlaceholdersPass,
4344
FuseQuantizedActivationPass,
4445
InsertRescalePass,
4546
InsertTableOpsPass,
@@ -113,6 +114,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
113114
self.add_pass(FuseConstantArgsPass(exported_program))
114115

115116
self.add_pass(InsertTableOpsPass(exported_program))
117+
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
116118
self.add_pass(AnnotateChannelsLastDimOrder())
117119
self.add_pass(InsertRescalePass())
118120

@@ -164,6 +166,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
164166
self.add_pass(FuseViewCopyTransform())
165167
self.add_pass(FuseConstantArgsPass(exported_program))
166168
self.add_pass(InsertTableOpsPass(exported_program))
169+
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
167170
self.add_pass(AnnotateChannelsLastDimOrder())
168171
self.add_pass(InsertRescalePass())
169172

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes.arm_pass_utils import (
8+
get_constant_placeholder_kind,
9+
get_param_tensor,
10+
is_param_node,
11+
)
12+
from executorch.backends.transforms.utils import (
13+
create_constant_placeholder,
14+
delete_constant_placeholder,
15+
)
16+
from executorch.exir import ExportedProgram
17+
from executorch.exir.pass_base import ExportPass, PassResult
18+
19+
20+
class FuseEqualPlaceholdersPass(ExportPass):
    """
    This pass optimizes memory usage by finding constant placeholders
    pointing to identical tensors and fusing them to one single placeholder
    with multiple users.

    Two placeholders are fused only when their tensors have the same dtype
    and ``torch.equal`` reports them as equal. The fused placeholder is named
    ``<name>_common`` after one node of the group and inherits that node's
    constant-placeholder kind.
    """

    def __init__(self, exported_program: ExportedProgram):
        """Store the exported program whose constant placeholders are fused."""
        self.exported_program = exported_program
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Fuse constant placeholders that hold identical tensors.

        Args:
            graph_module: The graph module whose placeholders are inspected
                and rewritten in place.

        Returns:
            A PassResult holding the (possibly re-traced) graph module and a
            flag telling whether any placeholders were fused.
        """
        modified = False

        # Resolve each candidate tensor exactly once, before the graph is
        # mutated, instead of looking it up again for every pairwise
        # comparison. Placeholders without a backing tensor cannot be fused.
        candidates = []
        for node in graph_module.graph.nodes:
            if not is_param_node(self.exported_program, node):
                continue
            tensor = get_param_tensor(self.exported_program, node)
            if tensor is not None:
                candidates.append((node, tensor))

        while candidates:
            node1, tensor1 = candidates.pop()

            # Find equal tensors. The explicit dtype check is needed because
            # torch.equal on its own is not a reliable dtype comparison
            # (depending on the PyTorch version, mixed dtypes are either
            # promoted before comparing or rejected with an error), and fusing
            # tensors of different dtypes would change the dtype seen by the
            # users of one of the placeholders.
            duplicates = [
                node2
                for node2, tensor2 in candidates
                if tensor1.dtype == tensor2.dtype and torch.equal(tensor1, tensor2)
            ]
            if not duplicates:
                continue

            common_name = node1.name + "_common"
            common_kind = get_constant_placeholder_kind(self.exported_program, node1)
            common_persistent_buffer = True

            with graph_module.graph.inserting_before(node1):
                common_node = create_constant_placeholder(
                    self.exported_program,
                    graph_module.graph,
                    common_name,
                    common_kind,
                    tensor1,
                    common_persistent_buffer,
                )

            # Redirect all users to the common placeholder and drop the
            # now-unused originals from both the graph and the program.
            for eq_node in (node1, *duplicates):
                eq_node.replace_all_uses_with(common_node)
                delete_constant_placeholder(self.exported_program, eq_node)

            # Deleted nodes must not be considered again.
            fused = set(duplicates)
            candidates = [
                (node, tensor) for node, tensor in candidates if node not in fused
            ]

            modified = True

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module=graph_module, modified=modified)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from copy import deepcopy
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
11+
FuseEqualPlaceholdersPass,
12+
)
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.Tensor]  # Single-element tuple holding the input tensor x
16+
17+
18+
class FuseWeightsConstants(torch.nn.Module):
    """
    Test module whose weights and biases are plain tensor attributes.

    ``weights2`` is a deep copy of ``weights1`` and ``bias2``/``bias3`` are
    deep copies of ``bias1``, so the module holds two groups of distinct
    tensor objects with identical contents.
    """

    # Operator-count expectations consumed by the test pipeline (none here).
    ops_before_pass = {}
    ops_after_pass = {}
    ops_not_after_pass = []

    def __init__(self):
        super().__init__()
        weights = torch.rand(1, 2, 1)
        self.weights1 = weights
        self.weights2 = deepcopy(weights)

        bias = torch.rand(1)
        self.bias1 = bias
        self.bias2 = deepcopy(bias)
        self.bias3 = deepcopy(bias)

    def forward(self, x):
        """Sum two 1-D convolutions of ``x`` and add the third bias copy."""
        first_branch = torch.conv1d(x, self.weights1, self.bias1)
        second_branch = torch.conv1d(x, self.weights2, self.bias2)
        return first_branch + second_branch + self.bias3
39+
40+
41+
class FuseWeightsStateDict(torch.nn.Module):
    """
    Test module with two linear layers holding identical parameters.

    ``fc2`` is a deep copy of ``fc1``, so the state dict contains two weight
    tensors and two bias tensors with pairwise identical contents.
    """

    # Operator-count expectations consumed by the test pipeline (none here).
    ops_before_pass = {}
    ops_after_pass = {}
    ops_not_after_pass = []

    def __init__(self):
        super().__init__()
        linear = torch.nn.Linear(in_features=8, out_features=2, bias=True)
        self.fc1 = linear
        self.fc2 = deepcopy(linear)

    def forward(self, x):
        """Apply both linear layers to ``x`` and return the sum."""
        first_out = self.fc1(x)
        second_out = self.fc2(x)
        return first_out + second_out
55+
56+
57+
def test_fuse_equal_placeholders_constants_tosa_MI():
    """Run the pass on FuseWeightsConstants and check the fused constants."""
    failure_msg = "FuseEqualPlaceholders constants failed"

    module = FuseWeightsConstants()
    data = (torch.rand(1, 2, 8),)
    pipeline = PassPipeline[input_t](
        module,
        data,
        tosa_version="TOSA-0.80+MI",
        ops_before_pass=module.ops_before_pass,
        ops_after_pass=module.ops_after_pass,
        passes_with_exported_program=[FuseEqualPlaceholdersPass],
    )
    pipeline.run()

    # Check that the weights and biases have been merged: exactly two
    # constants must remain, and both must be fused "_common" placeholders.
    exp_program = pipeline.tester.get_artifact().exported_program()
    constant_keys = [*exp_program.constants.keys()]
    assert len(constant_keys) == 2, failure_msg
    for key in constant_keys:
        assert "_common" in key, failure_msg
76+
77+
78+
def test_fuse_equal_placeholders_state_dict_tosa_MI():
    """Run the pass on FuseWeightsStateDict and check the fused state dict."""
    failure_msg = "FuseEqualPlaceholders state_dict failed"

    module = FuseWeightsStateDict()
    data = (torch.rand(1, 2, 8),)
    pipeline = PassPipeline[input_t](
        module,
        data,
        tosa_version="TOSA-0.80+MI",
        ops_before_pass=module.ops_before_pass,
        ops_after_pass=module.ops_after_pass,
        passes_with_exported_program=[FuseEqualPlaceholdersPass],
    )
    pipeline.run()

    # Check that the weights and biases have been merged: exactly two
    # state-dict entries must remain, and both must be fused "_common"
    # placeholders.
    exp_program = pipeline.tester.get_artifact().exported_program()
    state_dict_keys = [*exp_program.state_dict.keys()]
    assert len(state_dict_keys) == 2, failure_msg
    for key in state_dict_keys:
        assert "_common" in key, failure_msg

examples/arm/setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ fi
6060

6161
# vela
6262
vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
63-
vela_rev="425541302c7e4b6fbeca7c0061286b131ee507c3"
63+
vela_rev="859cc066178a87ff28230c1ce9bd370f1e98aa5a"
6464

6565
########
6666
### Functions

0 commit comments

Comments
 (0)