@@ -45,30 +45,30 @@ def forward(self, x):
45
45
def test_basic_generated_identifier (self ):
46
46
delegate_builder = DelegateMappingBuilder (generated_identifiers = True )
47
47
48
- expected_mapping = {0 : (0 , 1 , 2 , 3 )}
48
+ expected_mapping = {0 : (1 , 2 , 3 , 4 )}
49
49
self .assertEqual (
50
50
delegate_builder .insert_delegate_mapping_entry (nodes = self .nodes ), 0
51
51
)
52
52
self .assertEqual (delegate_builder .get_delegate_mapping (), expected_mapping )
53
53
54
- expected_mapping = {0 : (0 , 1 , 2 , 3 ), 1 : (0 ,)}
54
+ expected_mapping = {0 : (1 , 2 , 3 , 4 ), 1 : (1 ,)}
55
55
self .assertEqual (
56
56
delegate_builder .insert_delegate_mapping_entry (nodes = self .nodes [0 ]), 1
57
57
)
58
58
self .assertEqual (delegate_builder .get_delegate_mapping (), expected_mapping )
59
59
60
- expected_mapping = {0 : (0 , 1 , 2 , 3 ), 1 : (0 ,), 2 : (1 ,)}
60
+ expected_mapping = {0 : (1 , 2 , 3 , 4 ), 1 : (1 ,), 2 : (2 ,)}
61
61
self .assertEqual (
62
62
delegate_builder .insert_delegate_mapping_entry (handles = self .handles [2 ]),
63
63
2 ,
64
64
)
65
65
self .assertEqual (delegate_builder .get_delegate_mapping (), expected_mapping )
66
66
67
67
expected_mapping = {
68
- 0 : (0 , 1 , 2 , 3 ),
69
- 1 : (0 ,),
70
- 2 : (1 ,),
71
- 3 : (0 , 1 , 2 , 3 ),
68
+ 0 : (1 , 2 , 3 , 4 ),
69
+ 1 : (1 ,),
70
+ 2 : (2 ,),
71
+ 3 : (1 , 2 , 3 , 4 ),
72
72
}
73
73
self .assertEqual (
74
74
delegate_builder .insert_delegate_mapping_entry (handles = self .handles ), 3
@@ -144,7 +144,7 @@ def test_backend_with_delegate_mapping(self) -> None:
144
144
self .assertEqual (len (debug_handle_map ), 5 )
145
145
# Check to see that all the delegate debug indexes in the range [0,2] are present.
146
146
self .assertTrue (
147
- all (element in debug_handle_map .keys () for element in [0 , 1 , 2 , 3 ])
147
+ all (element in debug_handle_map .keys () for element in [1 , 2 , 3 , 4 ])
148
148
)
149
149
150
150
class CompositeModule (torch .nn .Module ):
@@ -200,7 +200,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]):
200
200
201
201
# Entry with a list of nodes
202
202
iden_1 = next (identifiers )
203
- expected_mapping = {iden_1 : (0 , 1 , 2 , 3 )}
203
+ expected_mapping = {iden_1 : (1 , 2 , 3 , 4 )}
204
204
self .assertEqual (
205
205
delegate_builder_nodes .insert_delegate_mapping_entry (
206
206
nodes = self .nodes , identifier = iden_1
@@ -222,7 +222,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]):
222
222
223
223
# Entry with a single node
224
224
iden_2 = next (identifiers )
225
- expected_mapping = {iden_1 : (0 , 1 , 2 , 3 ), iden_2 : (0 ,)}
225
+ expected_mapping = {iden_1 : (1 , 2 , 3 , 4 ), iden_2 : (1 ,)}
226
226
self .assertEqual (
227
227
delegate_builder_nodes .insert_delegate_mapping_entry (
228
228
nodes = self .nodes [0 ], identifier = iden_2
0 commit comments