Skip to content

Commit 044910b

Browse files
Add helper explicit_graph_inputs (#712)
1 parent 27bd9aa commit 044910b

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

pytensor/graph/basic.py

+49
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,55 @@ def graph_inputs(
936936
yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
937937

938938

939+
def explicit_graph_inputs(
940+
graph: Variable | Iterable[Variable],
941+
) -> Generator[Variable, None, None]:
942+
"""
943+
Get the root variables needed as inputs to a function that computes `graph`
944+
945+
Parameters
946+
----------
947+
graph : TensorVariable
948+
Output `Variable` instances for which to search backward through
949+
owners.
950+
951+
Returns
952+
-------
953+
iterable
954+
Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`.
955+
956+
Examples
957+
--------
958+
959+
.. code-block:: python
960+
961+
import pytensor
962+
import pytensor.tensor as pt
963+
from pytensor.graph.basic import explicit_graph_inputs
964+
965+
x = pt.vector('x')
966+
y = pt.constant(2)
967+
z = pt.mul(x*y)
968+
969+
inputs = list(explicit_graph_inputs(z))
970+
f = pytensor.function(inputs, z)
971+
eval = f([1, 2, 3])
972+
973+
print(eval)
974+
# [2. 4. 6.]
975+
"""
976+
from pytensor.compile.sharedvalue import SharedVariable
977+
978+
if isinstance(graph, Variable):
979+
graph = [graph]
980+
981+
return (
982+
v
983+
for v in graph_inputs(graph)
984+
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
985+
)
986+
987+
939988
def vars_between(
940989
ins: Collection[Variable], outs: Iterable[Variable]
941990
) -> Generator[Variable, None, None]:

tests/graph/test_basic.py

+15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
clone,
1919
clone_get_equiv,
2020
equal_computations,
21+
explicit_graph_inputs,
2122
general_toposort,
2223
get_var_by_name,
2324
graph_inputs,
@@ -522,6 +523,20 @@ def test_graph_inputs():
522523
assert res_list == [r3, r1, r2]
523524

524525

526+
def test_explicit_graph_inputs():
527+
x = pt.fscalar()
528+
y = pt.constant(2)
529+
z = shared(1)
530+
a = pt.sum(x + y + z)
531+
b = pt.true_div(x, y)
532+
533+
res = list(explicit_graph_inputs([a]))
534+
res1 = list(explicit_graph_inputs(b))
535+
536+
assert res == [x]
537+
assert res1 == [x]
538+
539+
525540
def test_variables_and_orphans():
526541
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
527542
o1 = MyOp(r1, r2)

0 commit comments

Comments
 (0)