|
| 1 | +import unittest |
| 2 | +from typing import Sequence |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
| 7 | +from executorch.exir import to_edge, to_edge_transform_and_lower |
| 8 | +from executorch.exir.passes import remove_unused_parameters_pass |
| 9 | +from executorch.runtime import Runtime |
| 10 | +from torch.export import ExportedProgram |
| 11 | + |
| 12 | + |
| 13 | +class TestRemoveUnusedParametersPass(unittest.TestCase): |
| 14 | + class SimpleModelWithUnusedParameters(torch.nn.Module): |
| 15 | + def __init__(self): |
| 16 | + super().__init__() |
| 17 | + self.linear1 = torch.nn.Linear(16, 16) |
| 18 | + self.unused_linear = torch.nn.Linear(1024, 1024) |
| 19 | + |
| 20 | + def forward(self, x): |
| 21 | + return self.linear1(x) |
| 22 | + |
| 23 | + class NestedModel(torch.nn.Module): |
| 24 | + def __init__(self): |
| 25 | + super().__init__() |
| 26 | + self.mod1 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters() |
| 27 | + self.mod2 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters() |
| 28 | + |
| 29 | + def forward(self, x): |
| 30 | + y = self.mod1(x) + self.mod2(x) |
| 31 | + y += self.mod1.unused_linear(x.repeat([1, 64]))[:, :16] |
| 32 | + return y |
| 33 | + |
| 34 | + def test_remove_unused_parameters_simple(self): |
| 35 | + model = self.SimpleModelWithUnusedParameters() |
| 36 | + model.eval() |
| 37 | + example_inputs = (torch.randn(1, 16),) |
| 38 | + eager_outputs = model(*example_inputs) |
| 39 | + ep = torch.export.export(model, example_inputs, strict=False) |
| 40 | + |
| 41 | + unused_param_names_and_args = { |
| 42 | + "unused_linear.weight": "p_unused_linear_weight", |
| 43 | + "unused_linear.bias": "p_unused_linear_bias", |
| 44 | + } |
| 45 | + |
| 46 | + self._test_pass(ep, unused_param_names_and_args, example_inputs, eager_outputs) |
| 47 | + |
| 48 | + def test_remove_unused_parameters_nested(self): |
| 49 | + model = self.NestedModel() |
| 50 | + model.eval() |
| 51 | + example_inputs = (torch.randn(1, 16),) |
| 52 | + eager_outputs = model(*example_inputs) |
| 53 | + ep = torch.export.export(model, example_inputs, strict=False) |
| 54 | + |
| 55 | + unused_param_names_and_args = { |
| 56 | + "mod2.unused_linear.weight": "p_mod2_unused_linear_weight", |
| 57 | + "mod2.unused_linear.bias": "p_mod2_unused_linear_bias", |
| 58 | + } |
| 59 | + |
| 60 | + self._test_pass(ep, unused_param_names_and_args, example_inputs, eager_outputs) |
| 61 | + |
| 62 | + def test_remove_unused_parameters_simple_e2e_to_edge(self): |
| 63 | + model = self.SimpleModelWithUnusedParameters().eval() |
| 64 | + example_inputs = (torch.randn(1, 16),) |
| 65 | + |
| 66 | + # There are approximately 1M unused fp32 parameters - ~4Mb. |
| 67 | + # Without the unused params, the expected size is ~2.5Kb. |
| 68 | + size_bound = 10000 |
| 69 | + |
| 70 | + for strict in [False, True]: |
| 71 | + for delegate in [False, True]: |
| 72 | + self._test_pass_e2e( |
| 73 | + model, |
| 74 | + example_inputs, |
| 75 | + strict=strict, |
| 76 | + use_to_edge=True, |
| 77 | + delegate=delegate, |
| 78 | + size_bound=size_bound, |
| 79 | + ) |
| 80 | + |
| 81 | + def test_remove_unused_parameters_simple_e2e_to_edge_transform_and_lower(self): |
| 82 | + model = self.SimpleModelWithUnusedParameters().eval() |
| 83 | + example_inputs = (torch.randn(1, 16),) |
| 84 | + |
| 85 | + # There are approximately 1M unused fp32 parameters - ~4Mb. |
| 86 | + # Without the unused params, the expected size is ~2.5Kb. |
| 87 | + size_bound = 10000 |
| 88 | + |
| 89 | + for strict in [False, True]: |
| 90 | + for delegate in [False, True]: |
| 91 | + self._test_pass_e2e( |
| 92 | + model, |
| 93 | + example_inputs, |
| 94 | + strict=strict, |
| 95 | + use_to_edge=False, |
| 96 | + delegate=delegate, |
| 97 | + size_bound=size_bound, |
| 98 | + ) |
| 99 | + |
| 100 | + def test_remove_unused_parameters_nested_e2e_to_edge(self): |
| 101 | + model = self.NestedModel().eval() |
| 102 | + example_inputs = (torch.randn(1, 16),) |
| 103 | + |
| 104 | + size_bound = 20000 + 1024 * 1024 * 4 |
| 105 | + |
| 106 | + for strict in [False, True]: |
| 107 | + for delegate in [False, True]: |
| 108 | + self._test_pass_e2e( |
| 109 | + model, |
| 110 | + example_inputs, |
| 111 | + strict=strict, |
| 112 | + use_to_edge=True, |
| 113 | + delegate=delegate, |
| 114 | + size_bound=size_bound, |
| 115 | + ) |
| 116 | + |
| 117 | + def test_remove_unused_parameters_nested_e2e_to_edge_transform_and_lower(self): |
| 118 | + model = self.SimpleModelWithUnusedParameters().eval() |
| 119 | + example_inputs = (torch.randn(1, 16),) |
| 120 | + |
| 121 | + size_bound = 20000 + 1024 * 1024 * 4 |
| 122 | + |
| 123 | + for strict in [False, True]: |
| 124 | + for delegate in [False, True]: |
| 125 | + self._test_pass_e2e( |
| 126 | + model, |
| 127 | + example_inputs, |
| 128 | + strict=strict, |
| 129 | + use_to_edge=False, |
| 130 | + delegate=delegate, |
| 131 | + size_bound=size_bound, |
| 132 | + ) |
| 133 | + |
| 134 | + def _test_pass( |
| 135 | + self, |
| 136 | + ep: ExportedProgram, |
| 137 | + unused_param_names_and_args: dict[str, str], |
| 138 | + example_inputs: Sequence[torch.Tensor], |
| 139 | + expected_outputs: torch.Tensor, |
| 140 | + ): |
| 141 | + # Verify EP state before running the pass. |
| 142 | + placeholders = { |
| 143 | + n.target for n in ep.graph_module.graph.nodes if n.op == "placeholder" |
| 144 | + } |
| 145 | + for param_name, param_arg in unused_param_names_and_args.items(): |
| 146 | + self.assertIn(param_name, ep.state_dict.keys()) |
| 147 | + self.assertIn(param_name, ep.graph_signature.parameters) |
| 148 | + self.assertIn(param_arg, placeholders) |
| 149 | + |
| 150 | + new_ep = remove_unused_parameters_pass(ep) |
| 151 | + |
| 152 | + # Verify that the unused params are not in the state dict, |
| 153 | + # graph signature, or graph. |
| 154 | + new_placeholders = { |
| 155 | + n.target for n in new_ep.graph_module.graph.nodes if n.op == "placeholder" |
| 156 | + } |
| 157 | + for param_name, param_arg in unused_param_names_and_args.items(): |
| 158 | + self.assertNotIn(param_name, new_ep.state_dict.keys()) |
| 159 | + self.assertNotIn(param_name, new_ep.graph_signature.parameters) |
| 160 | + self.assertNotIn(param_arg, new_placeholders) |
| 161 | + |
| 162 | + # Verify that the outputs are unchanged. |
| 163 | + new_outputs = new_ep.module()(*example_inputs) |
| 164 | + self.assertTrue(torch.allclose(new_outputs, expected_outputs)) |
| 165 | + |
| 166 | + def _test_pass_e2e( |
| 167 | + self, |
| 168 | + model: torch.nn.Module, |
| 169 | + example_inputs: Sequence[torch.Tensor], |
| 170 | + strict: bool, |
| 171 | + use_to_edge: bool, |
| 172 | + delegate: bool, |
| 173 | + size_bound: int, |
| 174 | + ): |
| 175 | + eager_outputs = model(*example_inputs) |
| 176 | + ep = torch.export.export(model, example_inputs, strict=strict) |
| 177 | + |
| 178 | + if use_to_edge: |
| 179 | + lowered = to_edge(ep) |
| 180 | + if delegate: |
| 181 | + lowered = lowered.to_backend(XnnpackPartitioner()) |
| 182 | + else: # use to_edge_transform_and_lower |
| 183 | + lowered = to_edge_transform_and_lower( |
| 184 | + ep, |
| 185 | + partitioner=[XnnpackPartitioner()] if delegate else [], |
| 186 | + ) |
| 187 | + |
| 188 | + lowered = lowered.to_executorch() |
| 189 | + self.assertLess(len(lowered.buffer), size_bound) |
| 190 | + |
| 191 | + # Make sure we can load and run the serialized .pte. |
| 192 | + runtime = Runtime.get() |
| 193 | + program = runtime.load_program(lowered.buffer) |
| 194 | + method = program.load_method("forward") |
| 195 | + runtime_outputs = method.execute([*example_inputs]) |
| 196 | + |
| 197 | + self.assertEqual(1, len(runtime_outputs)) |
| 198 | + self.assertTrue( |
| 199 | + torch.allclose(runtime_outputs[0], eager_outputs, atol=2e-6), |
| 200 | + "Values out of tolerance.\n" |
| 201 | + + f" Strict: {strict}, ToEdge: {use_to_edge}, Delegate: {delegate}.\n" |
| 202 | + + f" Eager: {eager_outputs}.\n" |
| 203 | + + f" Pybind: {runtime_outputs[0]}.\n" |
| 204 | + + f" Error: {eager_outputs - runtime_outputs[0]}", |
| 205 | + ) |
0 commit comments