
Commit 4dfddf5

Add pass to remove unused parameters in to_edge
Differential Revision: D73654202
Pull Request resolved: #10484
1 parent c6c0899 commit 4dfddf5

6 files changed: +296 −0 lines changed


exir/passes/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -21,6 +21,7 @@ python_library(
         ":quant_fusion_pass",
         ":quantize_io_pass",
         ":remove_noop_pass",
+        ":remove_unused_parameters_pass",
         ":replace_aten_with_edge_pass",
         ":replace_broken_ops_with_function_ops_pass",
         ":replace_edge_with_backend_pass",
@@ -390,3 +391,14 @@ python_library(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+python_library(
+    name = "remove_unused_parameters_pass",
+    srcs = [
+        "remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/dialects:lib",
+    ],
+)

exir/passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,9 @@
 from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
+from executorch.exir.passes.remove_unused_parameters_pass import (
+    remove_unused_parameters_pass,
+)
 from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
 from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
     ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -71,6 +74,7 @@
     "MemoryPlanningPass",
     "HintBasedSymShapeEvalPass",
     "insert_write_back_for_buffers_pass",
+    "remove_unused_parameters_pass",
     "weights_to_outputs_pass",
 ]

exir/passes/remove_unused_parameters_pass.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch

from torch.export.exported_program import ExportedProgram, InputKind


def remove_unused_parameters_pass(
    ep: ExportedProgram,
) -> ExportedProgram:
    """
    Remove unused parameters from the exported program.
    """

    placeholder_nodes = {
        node.target: node
        for node in ep.graph_module.graph.nodes
        if node.op == "placeholder"
    }

    unused_parameters = [
        s
        for s in ep.graph_signature.input_specs
        if s.kind == InputKind.PARAMETER
        and not _is_parameter_used(ep, s.arg.name, placeholder_nodes)
    ]

    # Remove params from the state dict, graph, and signature.
    new_signature = copy.deepcopy(ep.graph_signature)
    for param in unused_parameters:
        new_signature.input_specs.remove(param)
        del ep._state_dict[param.target]
        ep.graph_module.graph.erase_node(placeholder_nodes[param.arg.name])

    ep._graph_signature = new_signature
    ep.graph_module.recompile()
    return ep


def _is_parameter_used(
    ep: ExportedProgram, parameter: str, placeholder_nodes: dict[str, torch.fx.Node]
) -> bool:
    placeholder_node = placeholder_nodes.get(parameter)
    if placeholder_node is None:
        raise RuntimeError(
            f"Invalid graph. No placeholder for {parameter} found in graph."
        )

    return len(placeholder_node.users) > 0
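
For orientation, a minimal standalone usage sketch of the new pass. The model and
assertions here are illustrative, not part of the commit; the tests added below
exercise the same flow:

import torch
from executorch.exir.passes import remove_unused_parameters_pass


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        # Holds ~1M fp32 parameters but is never called in forward().
        self.unused = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        return self.linear(x)


ep = torch.export.export(TinyModel().eval(), (torch.randn(1, 16),), strict=False)
ep = remove_unused_parameters_pass(ep)

# The dead weights are removed from the state dict, graph signature, and graph.
assert "unused.weight" not in ep.state_dict
assert "unused.bias" not in ep.state_dict
assert "unused.weight" not in ep.graph_signature.parameters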

exir/program/_program.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@
     EdgeToBackendOpsPass,
     MemoryFormatOpsPass,
     OpReplacePass,
+    remove_unused_parameters_pass,
 )
 from executorch.exir.passes.external_constants_pass import (
     external_constants_pass,
@@ -801,6 +802,9 @@ def _generate_edge_program(
     assert gm_res is not None
     gm = gm_res.graph_module

+    # Remove unused parameters
+    program = remove_unused_parameters_pass(program)
+
     if config._check_ir_validity:
         try:
             EXIRATenDialectVerifier(
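
Because _generate_edge_program backs both to_edge and to_edge_transform_and_lower,
unused parameters are now dropped automatically during lowering. A rough sketch of
the end-to-end effect; the model and the size figure mirror the e2e tests added
below, so treat the bound as approximate rather than guaranteed:

import torch
from executorch.exir import to_edge


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.unused = torch.nn.Linear(1024, 1024)  # ~4 MB of fp32 weights, never used

    def forward(self, x):
        return self.linear(x)


ep = torch.export.export(TinyModel().eval(), (torch.randn(1, 16),), strict=False)
et_program = to_edge(ep).to_executorch()

# Without the pass these weights would be serialized into the .pte; with it,
# the tests below bound the resulting program at roughly 10 KB.
assert len(et_program.buffer) < 10_000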

exir/tests/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -432,6 +432,22 @@ python_unittest(
     ],
 )

+python_unittest(
+    name = "test_remove_unused_parameters_pass",
+    srcs = [
+        "test_remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack:xnnpack_delegate",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory",
+        "//executorch/exir/capture:config",
+        "//executorch/exir/passes:lib",
+        "//executorch/runtime:runtime",
+    ],
+)
+
 python_unittest(
     name = "test_remove_view_copy",
     srcs = [

exir/tests/test_remove_unused_parameters_pass.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
import unittest
from typing import Sequence

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge, to_edge_transform_and_lower
from executorch.exir.passes import remove_unused_parameters_pass
from executorch.runtime import Runtime
from torch.export import ExportedProgram


class TestRemoveUnusedParametersPass(unittest.TestCase):
    class SimpleModelWithUnusedParameters(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 16)
            self.unused_linear = torch.nn.Linear(1024, 1024)

        def forward(self, x):
            return self.linear1(x)

    class NestedModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mod1 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters()
            self.mod2 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters()

        def forward(self, x):
            y = self.mod1(x) + self.mod2(x)
            y += self.mod1.unused_linear(x.repeat([1, 64]))[:, :16]
            return y

    def test_remove_unused_parameters_simple(self):
        model = self.SimpleModelWithUnusedParameters()
        model.eval()
        example_inputs = (torch.randn(1, 16),)
        eager_outputs = model(*example_inputs)
        ep = torch.export.export(model, example_inputs, strict=False)

        unused_param_names_and_args = {
            "unused_linear.weight": "p_unused_linear_weight",
            "unused_linear.bias": "p_unused_linear_bias",
        }

        self._test_pass(ep, unused_param_names_and_args, example_inputs, eager_outputs)

    def test_remove_unused_parameters_nested(self):
        model = self.NestedModel()
        model.eval()
        example_inputs = (torch.randn(1, 16),)
        eager_outputs = model(*example_inputs)
        ep = torch.export.export(model, example_inputs, strict=False)

        unused_param_names_and_args = {
            "mod2.unused_linear.weight": "p_mod2_unused_linear_weight",
            "mod2.unused_linear.bias": "p_mod2_unused_linear_bias",
        }

        self._test_pass(ep, unused_param_names_and_args, example_inputs, eager_outputs)

    def test_remove_unused_parameters_simple_e2e_to_edge(self):
        model = self.SimpleModelWithUnusedParameters().eval()
        example_inputs = (torch.randn(1, 16),)

        # There are approximately 1M unused fp32 parameters - ~4Mb.
        # Without the unused params, the expected size is ~2.5Kb.
        size_bound = 10000

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=True,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def test_remove_unused_parameters_simple_e2e_to_edge_transform_and_lower(self):
        model = self.SimpleModelWithUnusedParameters().eval()
        example_inputs = (torch.randn(1, 16),)

        # There are approximately 1M unused fp32 parameters - ~4Mb.
        # Without the unused params, the expected size is ~2.5Kb.
        size_bound = 10000

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=False,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def test_remove_unused_parameters_nested_e2e_to_edge(self):
        model = self.NestedModel().eval()
        example_inputs = (torch.randn(1, 16),)

        size_bound = 20000 + 1024 * 1024 * 4

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=True,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def test_remove_unused_parameters_nested_e2e_to_edge_transform_and_lower(self):
        model = self.SimpleModelWithUnusedParameters().eval()
        example_inputs = (torch.randn(1, 16),)

        size_bound = 20000 + 1024 * 1024 * 4

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=False,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def _test_pass(
        self,
        ep: ExportedProgram,
        unused_param_names_and_args: dict[str, str],
        example_inputs: Sequence[torch.Tensor],
        expected_outputs: torch.Tensor,
    ):
        # Verify EP state before running the pass.
        placeholders = {
            n.target for n in ep.graph_module.graph.nodes if n.op == "placeholder"
        }
        for param_name, param_arg in unused_param_names_and_args.items():
            self.assertIn(param_name, ep.state_dict.keys())
            self.assertIn(param_name, ep.graph_signature.parameters)
            self.assertIn(param_arg, placeholders)

        new_ep = remove_unused_parameters_pass(ep)

        # Verify that the unused params are not in the state dict,
        # graph signature, or graph.
        new_placeholders = {
            n.target for n in new_ep.graph_module.graph.nodes if n.op == "placeholder"
        }
        for param_name, param_arg in unused_param_names_and_args.items():
            self.assertNotIn(param_name, new_ep.state_dict.keys())
            self.assertNotIn(param_name, new_ep.graph_signature.parameters)
            self.assertNotIn(param_arg, new_placeholders)

        # Verify that the outputs are unchanged.
        new_outputs = new_ep.module()(*example_inputs)
        self.assertTrue(torch.allclose(new_outputs, expected_outputs))

    def _test_pass_e2e(
        self,
        model: torch.nn.Module,
        example_inputs: Sequence[torch.Tensor],
        strict: bool,
        use_to_edge: bool,
        delegate: bool,
        size_bound: int,
    ):
        eager_outputs = model(*example_inputs)
        ep = torch.export.export(model, example_inputs, strict=strict)

        if use_to_edge:
            lowered = to_edge(ep)
            if delegate:
                lowered = lowered.to_backend(XnnpackPartitioner())
        else:  # use to_edge_transform_and_lower
            lowered = to_edge_transform_and_lower(
                ep,
                partitioner=[XnnpackPartitioner()] if delegate else [],
            )

        lowered = lowered.to_executorch()
        self.assertLess(len(lowered.buffer), size_bound)

        # Make sure we can load and run the serialized .pte.
        runtime = Runtime.get()
        program = runtime.load_program(lowered.buffer)
        method = program.load_method("forward")
        runtime_outputs = method.execute([*example_inputs])

        self.assertEqual(1, len(runtime_outputs))
        self.assertTrue(
            torch.allclose(runtime_outputs[0], eager_outputs, atol=2e-6),
            "Values out of tolerance.\n"
            + f" Strict: {strict}, ToEdge: {use_to_edge}, Delegate: {delegate}.\n"
            + f" Eager: {eager_outputs}.\n"
            + f" Pybind: {runtime_outputs[0]}.\n"
            + f" Error: {eager_outputs - runtime_outputs[0]}",
        )
