Skip to content

Allow defining an OpFromGraph from constant and shared inputs #676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 31 additions & 30 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,38 +92,29 @@ def construct_nominal_fgraph(
dict[Variable, Variable],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)

dummy_inputs.append(inp.type())
implicit_shared_inputs = []

dummy_shared_inputs = []
shared_inputs = []
dummy_inputs = [inp.type() for inp in inputs]
dummy_implicit_shared_inputs = []
for var in graph_inputs(outputs, inputs):
if var in inputs:
continue
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
# We allow shared inputs to be added automatically to the graph
implicit_shared_inputs.append(var)
dummy_implicit_shared_inputs.append(var.type())
elif not isinstance(var, Constant):
raise MissingInputError(f"NominalGraph is missing an input: {var}")

replacements = dict(
zip(
inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
)
)

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_inputs,
inputs=inputs + implicit_shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
Expand All @@ -133,7 +124,7 @@ def construct_nominal_fgraph(
(clone_d, update_d, update_expr, new_shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(shared_inputs)
assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
Expand All @@ -155,7 +146,7 @@ def construct_nominal_fgraph(
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)

return fgraph, shared_inputs, update_d, update_expr
return fgraph, implicit_shared_inputs, update_d, update_expr


class OpFromGraph(Op, HasInnerGraph):
Expand All @@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
- grad() make it support DisconnectedType and the new interface
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
- Add support/test with random generator
- Add optimization to removing unused inputs/outputs
Expand Down Expand Up @@ -310,11 +299,13 @@ def __init__(
self,
inputs: list[Variable],
outputs: list[Variable],
*,
inline: bool = False,
lop_overrides: str = "default",
grad_overrides: str = "default",
rop_overrides: str = "default",
connection_pattern: Optional[list[list[bool]]] = None,
strict: bool = False,
name: Optional[str] = None,
**kwargs,
):
Expand Down Expand Up @@ -399,6 +390,10 @@ def __init__(
must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this
:class:`Op`.
strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
are not provided as explici inputs. This can only happen for graphs with
shared variables.
name
A name for debugging purposes.
kwargs
Expand All @@ -424,6 +419,12 @@ def __init__(
inputs, outputs
)

if strict and self.shared_inputs:
raise ValueError(
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
)

self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
Expand Down
33 changes: 26 additions & 7 deletions tests/compile/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.basic import constant
from pytensor.tensor.math import dot, exp, sigmoid
from pytensor.tensor.math import round as pt_round
from pytensor.tensor.math import sum as pt_sum
Expand Down Expand Up @@ -43,12 +43,6 @@ def test_valid_input(self):
with pytest.raises(TypeError):
OpFromGraph([1], [1])

with pytest.raises(TypeError):
OpFromGraph([x, as_tensor(1)], [x])

with pytest.raises(TypeError):
OpFromGraph([shared(1)], [1])

with pytest.raises(NotImplementedError):
OpFromGraph([x], [x], updates={})

Expand Down Expand Up @@ -559,6 +553,31 @@ def test_outputs_consistency(self):
# The original `op.fgraph` outputs should stay the same, though
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])

def test_explicit_input_from_constant(self):
x = pt.dscalar("x")
y = constant(1.0, name="y")
test_ofg = OpFromGraph([x, y], [x + y])

out = test_ofg(x, y)
assert out.eval({x: 5}) == 6

def test_explicit_input_from_shared(self):
x = pt.dscalar("x")
y = shared(1.0, name="y")

with pytest.raises(
ValueError,
match=r"The inner-graph implicitly depends on the following shared variables \[y\]",
):
OpFromGraph([x], [x + y], strict=True)

test_ofg = OpFromGraph([x, y], [x + y], strict=True)

out = test_ofg(x, y)
assert out.eval({x: 5}) == 6
y.set_value(2.0)
assert out.eval({x: 6})


@config.change_flags(floatX="float64")
def test_debugprint():
Expand Down