@@ -232,13 +232,13 @@ def remove_client(
232
232
entry for `var` in `self.clients`.
233
233
234
234
"""
235
-
235
+ clients = self . clients
236
236
removal_stack = [(var , client_to_remove )]
237
237
while removal_stack :
238
238
var , client_to_remove = removal_stack .pop ()
239
239
240
240
try :
241
- var_clients = self . clients [var ]
241
+ var_clients = clients [var ]
242
242
var_clients .remove (client_to_remove )
243
243
except ValueError :
244
244
# In this case, the original `var` could've been removed from
@@ -256,9 +256,7 @@ def remove_client(
256
256
self .variables .remove (var )
257
257
else :
258
258
apply_node = var .owner
259
- if not any (
260
- output for output in apply_node .outputs if self .clients [output ]
261
- ):
259
+ if not any (clients [output ] for output in apply_node .outputs ):
262
260
# The `Apply` node is not used and is not an output, so we
263
261
# remove it and its outputs
264
262
if not hasattr (apply_node .tag , "removed_by" ):
@@ -276,7 +274,7 @@ def remove_client(
276
274
removal_stack .append ((in_var , (apply_node , i )))
277
275
278
276
if remove_if_empty :
279
- del self . clients [var ]
277
+ del clients [var ]
280
278
281
279
def import_var (
282
280
self , var : Variable , reason : str | None = None , import_missing : bool = False
@@ -563,10 +561,11 @@ def remove_node(self, node: Apply, reason: str | None = None):
563
561
node .tag .removed_by .append (str (reason ))
564
562
565
563
# Remove the outputs of the node (i.e. everything "below" it)
564
+ clients = self .clients
566
565
for out in node .outputs :
567
566
self .variables .remove (out )
568
567
569
- out_clients = self . clients .get (out , ())
568
+ out_clients = clients .get (out , ())
570
569
while out_clients :
571
570
out_client , out_idx = out_clients .pop ()
572
571
@@ -590,13 +589,12 @@ def remove_node(self, node: Apply, reason: str | None = None):
590
589
assert isinstance (out_client , Apply )
591
590
self .remove_node (out_client , reason = reason )
592
591
593
- if out in self .clients :
594
- del self .clients [out ]
592
+ clients .pop (out , None )
595
593
596
594
# Remove all the arrows pointing to this `node`, and any orphaned
597
595
# variables created by removing those arrows
598
596
for inp_idx , inp in enumerate (node .inputs ):
599
- inp_clients : list [ClientType ] = self . clients .get (inp , [])
597
+ inp_clients : list [ClientType ] = clients .get (inp , [])
600
598
601
599
arrow = (node , inp_idx )
602
600
@@ -810,12 +808,13 @@ def check_integrity(self) -> None:
810
808
raise Exception (
811
809
f"The following nodes are inappropriately cached:\n missing: { nodes_missing } \n in excess: { nodes_excess } "
812
810
)
811
+ clients = self .clients
813
812
for node in nodes :
814
813
for i , variable in enumerate (node .inputs ):
815
- clients = self . clients [variable ]
816
- if (node , i ) not in clients :
814
+ var_clients = clients [variable ]
815
+ if (node , i ) not in var_clients :
817
816
raise Exception (
818
- f"Inconsistent clients list { (node , i )} in { clients } "
817
+ f"Inconsistent clients list { (node , i )} in { var_clients } "
819
818
)
820
819
variables = set (vars_between (self .inputs , self .outputs ))
821
820
if set (self .variables ) != variables :
@@ -831,7 +830,7 @@ def check_integrity(self) -> None:
831
830
and not isinstance (variable , AtomicVariable )
832
831
):
833
832
raise Exception (f"Undeclared input: { variable } " )
834
- for cl_node , i in self . clients [variable ]:
833
+ for cl_node , i in clients [variable ]:
835
834
if cl_node == "output" :
836
835
if self .outputs [i ] is not variable :
837
836
raise Exception (
0 commit comments