Skip to content

Commit 9ba6d99

Browse files
committed
Replace str "output" by a dummy Op in the clients of the FunctionGraph
1 parent 7f623fe commit 9ba6d99

File tree

18 files changed

+172
-180
lines changed

18 files changed

+172
-180
lines changed

pytensor/compile/debugmode.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.graph.basic import Variable, io_toposort
3131
from pytensor.graph.destroyhandler import DestroyHandler
3232
from pytensor.graph.features import AlreadyThere, BadOptimization
33+
from pytensor.graph.fg import Output
3334
from pytensor.graph.op import HasInnerGraph, Op
3435
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3536
from pytensor.link.basic import Container, LocalLinker
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
628629
True if `var` is used by another node in the graph.
629630
630631
"""
631-
return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == [])
632+
return any(
633+
client for client, _ in fgraph.clients[var] if not isinstance(client.op, Output)
634+
)
632635

633636

634637
def _check_strides_match(a, b, warn_err, op):
@@ -977,7 +980,7 @@ def _check_preallocated_output(
977980
# disable memory checks in that mode, since they were already run.
978981
try:
979982
changed_inner_mode = False
980-
if isinstance(getattr(node, "op", None), HasInnerGraph):
983+
if isinstance(node.op, HasInnerGraph):
981984
fn = node.op.fn
982985
if not (fn and hasattr(fn, "maker") and hasattr(fn.maker, "mode")):
983986
_logger.warning(f"Expected pytensor function not found in {node.op}.fn")
@@ -1132,18 +1135,14 @@ class _FunctionGraphEvent:
11321135

11331136
def __init__(self, kind, node, idx=None, reason=None):
11341137
self.kind = kind
1135-
if node == "output":
1136-
self.node = "output"
1137-
self.op = "output"
1138-
else:
1139-
self.node = node
1140-
self.op = node.op
1138+
self.node = node
1139+
self.op = node.op
11411140
self.idx = idx
11421141
self.reason = str(reason)
11431142

11441143
def __str__(self):
11451144
if self.kind == "change":
1146-
if self.op != "output":
1145+
if not isinstance(self.op, Output):
11471146
msg = str(len(self.node.inputs))
11481147
else:
11491148
msg = ""

pytensor/compile/function/types.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
7878
"""
7979
treeset.add(v)
8080
for cl, v_input_pos_to_cl in fgraph.clients[v]:
81-
if cl == "output":
82-
continue
8381
vmap = cl.op.view_map
8482
dmap = cl.op.destroy_map
8583
for opos, iposlist in chain(vmap.items(), dmap.items()):
@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12021200
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
12031201

12041202
for i in range(len(fgraph.outputs)):
1203+
original_out = fgraph.outputs[i]
1204+
output_client = fgraph.get_output_client(i)
1205+
12051206
views_of_output_i = set()
1206-
view_tree_set(fgraph, alias_root(fgraph.outputs[i]), views_of_output_i)
1207+
view_tree_set(fgraph, alias_root(original_out), views_of_output_i)
12071208
copied = False
12081209
# do not allow outputs to be aliased
12091210
for j in range(i + 1, len(fgraph.outputs)):
@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12121213
if fgraph.outputs[j] in views_of_output_i:
12131214
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
12141215
fgraph.change_node_input(
1215-
"output", i, view_op(fgraph.outputs[i]), reason=reason
1216+
*output_client, view_op(original_out), reason=reason
12161217
)
12171218
else:
12181219
fgraph.change_node_input(
1219-
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
1220+
*output_client, deep_copy_op(original_out), reason=reason
12201221
)
12211222
copied = True
12221223
break
12231224

1224-
if not copied:
1225+
if not copied: # no-break
12251226
for input_j in all_graph_inputs:
12261227
# do not allow outputs to be aliased to an inputs (j), unless
12271228
# a) that j'th input has been 'destroyed' by
@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12391240
j = fgraph.inputs.index(input_j)
12401241
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
12411242
fgraph.change_node_input(
1242-
"output",
1243-
i,
1244-
view_op(fgraph.outputs[i]),
1243+
*output_client,
1244+
view_op(original_out),
12451245
reason=reason,
12461246
)
12471247
break
12481248
else:
12491249
fgraph.change_node_input(
1250-
"output",
1251-
i,
1252-
deep_copy_op(fgraph.outputs[i]),
1250+
*output_client,
1251+
deep_copy_op(original_out),
12531252
reason=reason,
12541253
)
12551254
break
12561255
elif wrapped_outputs[i].borrow:
12571256
fgraph.change_node_input(
1258-
"output",
1259-
i,
1260-
view_op(fgraph.outputs[i]),
1257+
*output_client,
1258+
view_op(original_out),
12611259
reason=reason,
12621260
)
12631261
break
12641262
else:
12651263
fgraph.change_node_input(
1266-
"output",
1267-
i,
1268-
deep_copy_op(fgraph.outputs[i]),
1264+
*output_client,
1265+
deep_copy_op(original_out),
12691266
reason=reason,
12701267
)
12711268
break

pytensor/compile/profiling.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,17 @@
1616
import time
1717
from collections import Counter, defaultdict
1818
from contextlib import contextmanager
19-
from typing import TYPE_CHECKING, Any
19+
from typing import Any
2020

2121
import numpy as np
2222

2323
import pytensor
2424
from pytensor.configdefaults import config
2525
from pytensor.graph.basic import Apply, Constant, Variable
26+
from pytensor.graph.fg import FunctionGraph, Output
2627
from pytensor.link.utils import get_destroy_dependencies
2728

2829

29-
if TYPE_CHECKING:
30-
from pytensor.graph.fg import FunctionGraph
31-
32-
3330
@contextmanager
3431
def extended_open(filename, mode="r"):
3532
if filename == "<stdout>":
@@ -1038,7 +1035,7 @@ def count_minimum_peak(node_list, fgraph, nodes_mem):
10381035
executable_nodes = set()
10391036
for var in fgraph.inputs:
10401037
for c, _ in fgraph.clients[var]:
1041-
if c != "output":
1038+
if not isinstance(c.op, Output):
10421039
deps = c.inputs + destroy_dependencies[c]
10431040
if all(compute_map[v][0] for v in deps):
10441041
executable_nodes.add(c)
@@ -1166,7 +1163,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
11661163

11671164
for var in node.outputs:
11681165
for c, _ in fgraph.clients[var]:
1169-
if c != "output":
1166+
if not isinstance(c.op, Output):
11701167
deps = c.inputs + destroy_dependencies[c]
11711168
if all(compute_map[v][0] for v in deps):
11721169
new_exec_nodes.add(c)

pytensor/graph/destroyhandler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Constant
1313
from pytensor.graph.features import AlreadyThere, Bookkeeper
14+
from pytensor.graph.fg import Output
1415
from pytensor.graph.utils import InconsistencyError
1516
from pytensor.misc.ordered_set import OrderedSet
1617

@@ -401,8 +402,6 @@ def has_destroyers(protected_list):
401402
def recursive_destroys_finder(protected_var):
402403
# protected_var is the idx'th input of app.
403404
for app, idx in fgraph.clients[protected_var]:
404-
if app == "output":
405-
continue
406405
destroy_maps = app.op.destroy_map.values()
407406
# If True means that the apply node, destroys the protected_var.
408407
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
@@ -578,7 +577,7 @@ def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
578577
app.inputs[i] changed from old_r to new_r.
579578
580579
"""
581-
if app == "output":
580+
if isinstance(app.op, Output):
582581
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
583582
# considered 'outputs' of the graph.
584583
pass

0 commit comments

Comments
 (0)