Skip to content

Commit be0e13a

Browse files
committed
Reduce number of access to self.clients in FunctionGraph
To speedup hot rewrite loops
1 parent 36c55f5 commit be0e13a

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

pytensor/graph/fg.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,13 @@ def remove_client(
232232
entry for `var` in `self.clients`.
233233
234234
"""
235-
235+
clients = self.clients
236236
removal_stack = [(var, client_to_remove)]
237237
while removal_stack:
238238
var, client_to_remove = removal_stack.pop()
239239

240240
try:
241-
var_clients = self.clients[var]
241+
var_clients = clients[var]
242242
var_clients.remove(client_to_remove)
243243
except ValueError:
244244
# In this case, the original `var` could've been removed from
@@ -256,9 +256,7 @@ def remove_client(
256256
self.variables.remove(var)
257257
else:
258258
apply_node = var.owner
259-
if not any(
260-
output for output in apply_node.outputs if self.clients[output]
261-
):
259+
if not any(clients[output] for output in apply_node.outputs):
262260
# The `Apply` node is not used and is not an output, so we
263261
# remove it and its outputs
264262
if not hasattr(apply_node.tag, "removed_by"):
@@ -276,7 +274,7 @@ def remove_client(
276274
removal_stack.append((in_var, (apply_node, i)))
277275

278276
if remove_if_empty:
279-
del self.clients[var]
277+
del clients[var]
280278

281279
def import_var(
282280
self, var: Variable, reason: str | None = None, import_missing: bool = False
@@ -563,10 +561,11 @@ def remove_node(self, node: Apply, reason: str | None = None):
563561
node.tag.removed_by.append(str(reason))
564562

565563
# Remove the outputs of the node (i.e. everything "below" it)
564+
clients = self.clients
566565
for out in node.outputs:
567566
self.variables.remove(out)
568567

569-
out_clients = self.clients.get(out, ())
568+
out_clients = clients.get(out, ())
570569
while out_clients:
571570
out_client, out_idx = out_clients.pop()
572571

@@ -590,13 +589,12 @@ def remove_node(self, node: Apply, reason: str | None = None):
590589
assert isinstance(out_client, Apply)
591590
self.remove_node(out_client, reason=reason)
592591

593-
if out in self.clients:
594-
del self.clients[out]
592+
clients.pop(out, None)
595593

596594
# Remove all the arrows pointing to this `node`, and any orphaned
597595
# variables created by removing those arrows
598596
for inp_idx, inp in enumerate(node.inputs):
599-
inp_clients: list[ClientType] = self.clients.get(inp, [])
597+
inp_clients: list[ClientType] = clients.get(inp, [])
600598

601599
arrow = (node, inp_idx)
602600

@@ -810,12 +808,13 @@ def check_integrity(self) -> None:
810808
raise Exception(
811809
f"The following nodes are inappropriately cached:\nmissing: {nodes_missing}\nin excess: {nodes_excess}"
812810
)
811+
clients = self.clients
813812
for node in nodes:
814813
for i, variable in enumerate(node.inputs):
815-
clients = self.clients[variable]
816-
if (node, i) not in clients:
814+
var_clients = clients[variable]
815+
if (node, i) not in var_clients:
817816
raise Exception(
818-
f"Inconsistent clients list {(node, i)} in {clients}"
817+
f"Inconsistent clients list {(node, i)} in {var_clients}"
819818
)
820819
variables = set(vars_between(self.inputs, self.outputs))
821820
if set(self.variables) != variables:
@@ -831,7 +830,7 @@ def check_integrity(self) -> None:
831830
and not isinstance(variable, AtomicVariable)
832831
):
833832
raise Exception(f"Undeclared input: {variable}")
834-
for cl_node, i in self.clients[variable]:
833+
for cl_node, i in clients[variable]:
835834
if cl_node == "output":
836835
if self.outputs[i] is not variable:
837836
raise Exception(

0 commit comments

Comments
 (0)