Skip to content

Commit 8ca9ed5

Browse files
committed
Allow defining an OpFromGraph from constant and shared inputs.
Also adds a strict flag
1 parent 339aab4 commit 8ca9ed5

File tree

2 files changed

+57
-37
lines changed

2 files changed

+57
-37
lines changed

pytensor/compile/builders.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
9292
dict[Variable, Variable],
9393
]:
9494
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
95-
dummy_inputs = []
96-
for n, inp in enumerate(inputs):
97-
if (
98-
not isinstance(inp, Variable)
99-
or isinstance(inp, Constant)
100-
or isinstance(inp, SharedVariable)
101-
):
102-
raise TypeError(
103-
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
104-
)
105-
106-
dummy_inputs.append(inp.type())
95+
implicit_shared_inputs = []
10796

108-
dummy_shared_inputs = []
109-
shared_inputs = []
97+
dummy_inputs = [inp.type() for inp in inputs]
98+
dummy_implicit_shared_inputs = []
11099
for var in graph_inputs(outputs, inputs):
100+
if var in inputs:
101+
continue
111102
if isinstance(var, SharedVariable):
112-
# To correctly support shared variables the inner-graph should
113-
# not see them; otherwise, there will be problems with
114-
# gradients.
115-
# That's why we collect the shared variables and replace them
116-
# with dummies.
117-
shared_inputs.append(var)
118-
dummy_shared_inputs.append(var.type())
119-
elif var not in inputs and not isinstance(var, Constant):
120-
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
121-
122-
replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
103+
# We allow shared inputs to be added automatically to the graph
104+
implicit_shared_inputs.append(var)
105+
dummy_implicit_shared_inputs.append(var.type())
106+
elif not isinstance(var, Constant):
107+
raise MissingInputError(f"NominalGraph is missing an input: {var}")
108+
109+
replacements = dict(
110+
zip(
111+
inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
112+
)
113+
)
123114

124115
new = rebuild_collect_shared(
125116
cast(Sequence[Variable], outputs),
126-
inputs=inputs + shared_inputs,
117+
inputs=inputs + implicit_shared_inputs,
127118
replace=replacements,
128119
copy_inputs_over=False,
129120
)
@@ -133,7 +124,7 @@ def construct_nominal_fgraph(
133124
(clone_d, update_d, update_expr, new_shared_inputs),
134125
) = new
135126

136-
assert len(local_inputs) == len(inputs) + len(shared_inputs)
127+
assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
137128
assert len(local_outputs) == len(outputs)
138129
assert not update_d
139130
assert not update_expr
@@ -155,7 +146,7 @@ def construct_nominal_fgraph(
155146
fgraph.clients.pop(inp, None)
156147
fgraph.add_input(nom_inp)
157148

158-
return fgraph, shared_inputs, update_d, update_expr
149+
return fgraph, implicit_shared_inputs, update_d, update_expr
159150

160151

161152
class OpFromGraph(Op, HasInnerGraph):
@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
177168
- grad() make it support DisconnectedType and the new interface
178169
- add support for NullType and DisconnectedType when R_op supports them
179170
- check how it works with updates.
180-
- add test with constant as input or inside the inner graph.
181-
- Add support for the GPU? Probably just need an opt to remove transfer
182171
- Add support to pickle this Op.
183172
- Add support/test with random generator
184173
- Add optimization to removing unused inputs/outputs
@@ -310,11 +299,13 @@ def __init__(
310299
self,
311300
inputs: list[Variable],
312301
outputs: list[Variable],
302+
*,
313303
inline: bool = False,
314304
lop_overrides: str = "default",
315305
grad_overrides: str = "default",
316306
rop_overrides: str = "default",
317307
connection_pattern: Optional[list[list[bool]]] = None,
308+
strict: bool = False,
318309
name: Optional[str] = None,
319310
**kwargs,
320311
):
@@ -399,6 +390,10 @@ def __init__(
399390
must be equal to number of outputs. connection_pattern If not
400391
``None``, this will be used as the connection_pattern for this
401392
:class:`Op`.
393+
strict: bool, default False
394+
If true, it raises when any variables needed to compute the inner graph
395+
are not provided as explici inputs. This can only happen for graphs with
396+
shared variables.
402397
name
403398
A name for debugging purposes.
404399
kwargs
@@ -424,6 +419,12 @@ def __init__(
424419
inputs, outputs
425420
)
426421

422+
if strict and self.shared_inputs:
423+
raise ValueError(
424+
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
425+
f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
426+
)
427+
427428
self.kwargs = kwargs
428429
self.input_types = [inp.type for inp in inputs]
429430
self.output_types = [out.type for out in outputs]

tests/compile/test_builders.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pytensor.graph.rewriting.utils import rewrite_graph
1616
from pytensor.graph.utils import MissingInputError
1717
from pytensor.printing import debugprint
18-
from pytensor.tensor.basic import as_tensor
18+
from pytensor.tensor.basic import constant
1919
from pytensor.tensor.math import dot, exp, sigmoid
2020
from pytensor.tensor.math import round as pt_round
2121
from pytensor.tensor.math import sum as pt_sum
@@ -43,12 +43,6 @@ def test_valid_input(self):
4343
with pytest.raises(TypeError):
4444
OpFromGraph([1], [1])
4545

46-
with pytest.raises(TypeError):
47-
OpFromGraph([x, as_tensor(1)], [x])
48-
49-
with pytest.raises(TypeError):
50-
OpFromGraph([shared(1)], [1])
51-
5246
with pytest.raises(NotImplementedError):
5347
OpFromGraph([x], [x], updates={})
5448

@@ -559,6 +553,31 @@ def test_outputs_consistency(self):
559553
# The original `op.fgraph` outputs should stay the same, though
560554
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
561555

556+
def test_explicit_input_from_constant(self):
557+
x = pt.dscalar("x")
558+
y = constant(1.0, name="y")
559+
test_ofg = OpFromGraph([x, y], [x + y])
560+
561+
out = test_ofg(x, y)
562+
assert out.eval({x: 5}) == 6
563+
564+
def test_explicit_input_from_shared(self):
565+
x = pt.dscalar("x")
566+
y = shared(1.0, name="y")
567+
568+
with pytest.raises(
569+
ValueError,
570+
match=r"The inner-graph implicitly depends on the following shared variables \[y\]",
571+
):
572+
OpFromGraph([x], [x + y], strict=True)
573+
574+
test_ofg = OpFromGraph([x, y], [x + y], strict=True)
575+
576+
out = test_ofg(x, y)
577+
assert out.eval({x: 5}) == 6
578+
y.set_value(2.0)
579+
assert out.eval({x: 6})
580+
562581

563582
@config.change_flags(floatX="float64")
564583
def test_debugprint():

0 commit comments

Comments
 (0)