Skip to content

Commit fe295b9

Browse files
tarun292facebook-github-bot
authored andcommitted
Add helper method to generate missing debug handles (#5902)
Summary: Pull Request resolved: #5902 Helper function to generate missing debug handles on nodes, which is usually needed when graph transforms are done and new nodes are inserted. Reviewed By: Vysarat Differential Revision: D63913905 fbshipit-source-id: 0a24afa1d5207356a2706db88ff7828a5fba7f1a
1 parent 62a13c1 commit fe295b9

File tree

3 files changed

+64
-13
lines changed

3 files changed

+64
-13
lines changed

exir/backend/test/test_delegate_map_builder.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -45,30 +45,30 @@ def forward(self, x):
4545
def test_basic_generated_identifier(self):
4646
delegate_builder = DelegateMappingBuilder(generated_identifiers=True)
4747

48-
expected_mapping = {0: (0, 1, 2, 3)}
48+
expected_mapping = {0: (1, 2, 3, 4)}
4949
self.assertEqual(
5050
delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0
5151
)
5252
self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)
5353

54-
expected_mapping = {0: (0, 1, 2, 3), 1: (0,)}
54+
expected_mapping = {0: (1, 2, 3, 4), 1: (1,)}
5555
self.assertEqual(
5656
delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[0]), 1
5757
)
5858
self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)
5959

60-
expected_mapping = {0: (0, 1, 2, 3), 1: (0,), 2: (1,)}
60+
expected_mapping = {0: (1, 2, 3, 4), 1: (1,), 2: (2,)}
6161
self.assertEqual(
6262
delegate_builder.insert_delegate_mapping_entry(handles=self.handles[2]),
6363
2,
6464
)
6565
self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)
6666

6767
expected_mapping = {
68-
0: (0, 1, 2, 3),
69-
1: (0,),
70-
2: (1,),
71-
3: (0, 1, 2, 3),
68+
0: (1, 2, 3, 4),
69+
1: (1,),
70+
2: (2,),
71+
3: (1, 2, 3, 4),
7272
}
7373
self.assertEqual(
7474
delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3
@@ -144,7 +144,7 @@ def test_backend_with_delegate_mapping(self) -> None:
144144
self.assertEqual(len(debug_handle_map), 5)
145145
# Check to see that all the delegate debug indexes in the range [0,2] are present.
146146
self.assertTrue(
147-
all(element in debug_handle_map.keys() for element in [0, 1, 2, 3])
147+
all(element in debug_handle_map.keys() for element in [1, 2, 3, 4])
148148
)
149149

150150
class CompositeModule(torch.nn.Module):
@@ -200,7 +200,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]):
200200

201201
# Entry with a list of nodes
202202
iden_1 = next(identifiers)
203-
expected_mapping = {iden_1: (0, 1, 2, 3)}
203+
expected_mapping = {iden_1: (1, 2, 3, 4)}
204204
self.assertEqual(
205205
delegate_builder_nodes.insert_delegate_mapping_entry(
206206
nodes=self.nodes, identifier=iden_1
@@ -222,7 +222,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]):
222222

223223
# Entry with a single node
224224
iden_2 = next(identifiers)
225-
expected_mapping = {iden_1: (0, 1, 2, 3), iden_2: (0,)}
225+
expected_mapping = {iden_1: (1, 2, 3, 4), iden_2: (1,)}
226226
self.assertEqual(
227227
delegate_builder_nodes.insert_delegate_mapping_entry(
228228
nodes=self.nodes[0], identifier=iden_2

exir/passes/debug_handle_generator_pass.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from executorch.exir.graph_module import get_control_flow_submodules
88
from executorch.exir.pass_base import ExportPass
9+
from torch.export import ExportedProgram
910
from torch.fx import GraphModule
1011
from torch.fx.passes.infra.pass_base import PassResult
1112

@@ -17,7 +18,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
1718
"""
1819

1920
queue = [graph_module]
20-
index = 0
21+
index = 1
2122
# bfs to traverse all modules including control flow submodules to attached debug handle id
2223
while queue:
2324
current_graph_module = queue.pop(0)
@@ -30,3 +31,35 @@ def call(self, graph_module: GraphModule) -> PassResult:
3031
]
3132
queue.extend(control_flow_submodules)
3233
return PassResult(graph_module, True)
34+
35+
36+
def generate_missing_debug_handles(ep: ExportedProgram):
37+
"""
38+
This pass is used to generate missing debug handles for the graph module and its submodules.
39+
"""
40+
41+
def get_control_flow_submodules_list(graph_module):
42+
return [
43+
submodule for _, submodule, _ in get_control_flow_submodules(graph_module)
44+
]
45+
46+
max_handle = 0
47+
queue = [ep.graph_module]
48+
49+
while queue:
50+
current_graph_module = queue.pop(0)
51+
for node in current_graph_module.graph.nodes:
52+
if "debug_handle" in node.meta:
53+
max_handle = max(max_handle, node.meta["debug_handle"])
54+
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
55+
queue.extend(control_flow_submodules)
56+
57+
queue = [ep.graph_module]
58+
while queue:
59+
current_graph_module = queue.pop(0)
60+
for node in current_graph_module.graph.nodes:
61+
if node.meta.get("debug_handle", 0) in (0, None):
62+
node.meta["debug_handle"] = max_handle + 1
63+
max_handle += 1
64+
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
65+
queue.extend(control_flow_submodules)

exir/tests/test_passes.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
ToOutVarPass,
3434
)
3535
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
36-
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
36+
from executorch.exir.passes.debug_handle_generator_pass import (
37+
DebugHandleGeneratorPass,
38+
generate_missing_debug_handles,
39+
)
3740
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
3841
insert_write_back_for_buffers_pass,
3942
)
@@ -949,13 +952,28 @@ def test_debug_handle_generator_pass(self) -> None:
949952
.exported_program()
950953
.graph_module
951954
)
952-
DebugHandleGeneratorPass()(graph_module)
953955
for node in graph_module.graph.nodes:
954956
self.assertIn("debug_handle", node.meta)
955957
ScalarToTensorPass()(graph_module)
956958
for node in graph_module.graph.nodes:
957959
self.assertIn("debug_handle", node.meta)
958960

961+
def test_generate_missing_debug_handles(self) -> None:
962+
eager_model = MLP(2, output_size=4)
963+
inputs = eager_model.get_random_inputs()
964+
965+
ep = to_edge(
966+
export(
967+
eager_model,
968+
inputs,
969+
)
970+
).exported_program()
971+
972+
list(ep.graph.nodes)[0].meta.pop("debug_handle")
973+
self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
974+
generate_missing_debug_handles(ep)
975+
self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)
976+
959977
def test_debug_handle_generator_pass_with_control_flow(self) -> None:
960978
def true_nested(y: torch.Tensor) -> torch.Tensor:
961979
y = y + y

0 commit comments

Comments
 (0)