
Commit e15c09d

brandonwillard authored and ricardoV94 committed

Generalize the inner-FunctionGraph construction process

1 parent 762c4c5 commit e15c09d

File tree: 3 files changed, +88 -99 lines changed


pytensor/compile/builders.py
+78 -64

@@ -2,7 +2,7 @@
 from collections import OrderedDict
 from copy import copy
 from functools import partial
-from typing import List, Optional, Sequence, cast
+from typing import Dict, List, Optional, Sequence, Tuple, cast
 
 import pytensor.tensor as at
 from pytensor import function
@@ -81,6 +81,81 @@ def local_traverse(out):
     return ret
 
 
+def construct_nominal_fgraph(
+    inputs: Sequence[Variable], outputs: Sequence[Variable]
+) -> Tuple[
+    FunctionGraph,
+    Sequence[Variable],
+    Dict[Variable, Variable],
+    Dict[Variable, Variable],
+]:
+    """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
+    dummy_inputs = []
+    for n, inp in enumerate(inputs):
+        if (
+            not isinstance(inp, Variable)
+            or isinstance(inp, Constant)
+            or isinstance(inp, SharedVariable)
+        ):
+            raise TypeError(
+                f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
+            )
+
+        dummy_inputs.append(inp.type())
+
+    dummy_shared_inputs = []
+    shared_inputs = []
+    for var in graph_inputs(outputs, inputs):
+        if isinstance(var, SharedVariable):
+            # To correctly support shared variables the inner-graph should
+            # not see them; otherwise, there will be problems with
+            # gradients.
+            # That's why we collect the shared variables and replace them
+            # with dummies.
+            shared_inputs.append(var)
+            dummy_shared_inputs.append(var.type())
+        elif var not in inputs and not isinstance(var, Constant):
+            raise MissingInputError(f"OpFromGraph is missing an input: {var}")
+
+    replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
+
+    new = rebuild_collect_shared(
+        cast(Sequence[Variable], outputs),
+        inputs=inputs + shared_inputs,
+        replace=replacements,
+        copy_inputs_over=False,
+    )
+    (
+        local_inputs,
+        local_outputs,
+        (clone_d, update_d, update_expr, new_shared_inputs),
+    ) = new
+
+    assert len(local_inputs) == len(inputs) + len(shared_inputs)
+    assert len(local_outputs) == len(outputs)
+    assert not update_d
+    assert not update_expr
+    assert not new_shared_inputs
+
+    fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
+
+    # The inputs need to be `NominalVariable`s so that we can merge
+    # inner-graphs
+    nominal_local_inputs = tuple(
+        NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
+    )
+
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+
+    for i, inp in enumerate(fgraph.inputs):
+        nom_inp = nominal_local_inputs[i]
+        fgraph.inputs[i] = nom_inp
+        fgraph.clients.pop(inp, None)
+        fgraph.add_input(nom_inp)
+
+    return fgraph, shared_inputs, update_d, update_expr
+
+
 class OpFromGraph(Op, HasInnerGraph):
     r"""
     This creates an `Op` from inputs and outputs lists of variables.
@@ -338,76 +413,15 @@ def __init__(
                     f"Inputs and outputs must be Variable instances; got {out}"
                 )
 
-        dummy_inputs = []
-        for n, inp in enumerate(inputs):
-            if (
-                not isinstance(inp, Variable)
-                or isinstance(inp, Constant)
-                or isinstance(inp, SharedVariable)
-            ):
-                raise TypeError(
-                    f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
-                )
-
-            dummy_inputs.append(inp.type())
-
         if "updates" in kwargs or "givens" in kwargs:
             raise NotImplementedError("Updates and givens are not supported")
 
         self.is_inline = inline
 
-        dummy_shared_inputs = []
-        self.shared_inputs = []
-        for var in graph_inputs(outputs, inputs):
-            if isinstance(var, SharedVariable):
-                # To correctly support shared variables the inner-graph should
-                # not see them; otherwise, there will be problems with
-                # gradients.
-                # That's why we collect the shared variables and replace them
-                # with dummies.
-                self.shared_inputs.append(var)
-                dummy_shared_inputs.append(var.type())
-            elif var not in inputs and not isinstance(var, Constant):
-                raise MissingInputError(f"OpFromGraph is missing an input: {var}")
-
-        replacements = dict(
-            zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
+        self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
+            inputs, outputs
         )
 
-        new = rebuild_collect_shared(
-            cast(Sequence[Variable], outputs),
-            inputs=inputs + self.shared_inputs,
-            replace=replacements,
-            copy_inputs_over=False,
-        )
-        (
-            local_inputs,
-            local_outputs,
-            (clone_d, update_d, update_expr, shared_inputs),
-        ) = new
-
-        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
-        assert len(local_outputs) == len(outputs)
-        assert not update_d
-        assert not update_expr
-        assert not shared_inputs
-
-        self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
-
-        # The inputs need to be `NominalVariable`s so that we can merge
-        # inner-graphs
-        nominal_local_inputs = tuple(
-            NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
-        )
-
-        self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
-
-        for i, inp in enumerate(self.fgraph.inputs):
-            nom_inp = nominal_local_inputs[i]
-            self.fgraph.inputs[i] = nom_inp
-            self.fgraph.clients.pop(inp, None)
-            self.fgraph.add_input(nom_inp)
-
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
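
For context (not part of the commit): a minimal sketch of what the extracted helper produces, assuming only the `construct_nominal_fgraph` signature and behavior shown in the diff above; the variables are illustrative.

import pytensor.tensor as at
from pytensor.compile.builders import construct_nominal_fgraph
from pytensor.graph.basic import NominalVariable

x = at.scalar("x")
y = at.scalar("y")
out = x + 2 * y

# The helper returns the inner `FunctionGraph`, the shared variables it
# collected, and the (empty here) update mappings.
fgraph, shared_inputs, update_d, update_expr = construct_nominal_fgraph([x, y], [out])

# The inner-graph inputs are positional `NominalVariable`s, which is what
# allows structurally identical inner-graphs to be merged or compared.
assert all(isinstance(inp, NominalVariable) for inp in fgraph.inputs)
assert shared_inputs == []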

pytensor/scan/op.py
+10 -31

@@ -55,8 +55,7 @@
 
 import pytensor
 from pytensor import tensor as at
-from pytensor.compile import SharedVariable
-from pytensor.compile.builders import infer_shape
+from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
 from pytensor.compile.function.pfunc import pfunc
 from pytensor.compile.io import In, Out
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
@@ -65,17 +64,13 @@
 from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
 from pytensor.graph.basic import (
     Apply,
-    Constant,
-    NominalVariable,
     Variable,
     clone_replace,
     equal_computations,
     graph_inputs,
     io_connection_pattern,
-    replace_nominals_with_dummies,
 )
 from pytensor.graph.features import NoOutputFromInplace
-from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import HasInnerGraph, Op
 from pytensor.graph.utils import InconsistencyError, MissingInputError
 from pytensor.link.c.basic import CLinker
@@ -755,22 +750,12 @@ def __init__(
             If ``True``, all the shared variables used in the inner-graph must be provided.
 
         """
-        inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
+        self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs)
 
-        input_replacements = []
-        for n, v in enumerate(inputs):
-            if not isinstance(v, (SharedVariable, Constant)):
-                input_replacements.append((v, NominalVariable(n, v.type)))
-
-            assert not isinstance(v, NominalVariable)
-
-        outputs = clone_replace(outputs, replace=input_replacements)
-
-        if input_replacements:
-            _, inputs_ = zip(*input_replacements)
-            inputs = list(inputs_)
-        else:
-            inputs = []
+        # The shared variables should have been removed, so, if there are
+        # any, it's because the user didn't specify an input.
+        if shared_inputs:
+            raise MissingInputError(f"Scan is missing inputs: {shared_inputs}")
 
         self.info = info
         self.truncate_gradient = truncate_gradient
@@ -782,7 +767,7 @@ def __init__(
         # Clone mode_instance, altering "allow_gc" for the linker,
         # and adding a message if we profile
         if self.name:
-            message = self.name + " sub profile"
+            message = f"{self.name} sub profile"
         else:
             message = "Scan sub profile"
 
@@ -805,7 +790,7 @@ def tensorConstructor(shape, dtype):
         while idx < info.n_mit_mot_outs:
             # Not that for mit_mot there are several output slices per
             # output sequence
-            o = outputs[idx]
+            o = self.fgraph.outputs[idx]
             self.output_types.append(
                 # TODO: What can we actually say about the shape of this
                 # added dimension?
@@ -818,15 +803,15 @@ def tensorConstructor(shape, dtype):
         # mit_sot / sit_sot / nit_sot
         end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
 
-        for o in outputs[idx:end]:
+        for o in self.fgraph.outputs[idx:end]:
             self.output_types.append(
                 # TODO: What can we actually say about the shape of this
                 # added dimension?
                 typeConstructor((None,) + o.type.shape, o.type.dtype)
             )
 
         # shared outputs + possibly the ending condition
-        for o in outputs[end:]:
+        for o in self.fgraph.outputs[end:]:
             self.output_types.append(o.type)
 
         if info.as_while:
@@ -862,19 +847,13 @@ def tensorConstructor(shape, dtype):
         self.n_outer_inputs = info.n_outer_inputs
         self.n_outer_outputs = info.n_outer_outputs
 
-        self.fgraph = FunctionGraph(inputs, outputs, clone=False)
-
         _ = self.prepare_fgraph(self.fgraph)
 
         if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
             raise InconsistencyError(
                 "Inner-graphs must not contain in-place operations."
             )
 
-        # Do the missing inputs check here to have the error early.
-        for var in graph_inputs(self.inner_outputs, self.inner_inputs):
-            if var not in self.inner_inputs and not isinstance(var, Constant):
-                raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
         self._cmodule_key = CLinker().cmodule_key_variables(
             self.inner_inputs, self.inner_outputs, []
         )
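
As an illustrative sketch (again not part of the commit) of why the early `MissingInputError` above works: `construct_nominal_fgraph` pulls shared variables out of the inner graph and reports them, so `Scan`, whose inner-graph inputs must all be explicit, can reject them up front instead of failing later. Names here are illustrative.

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.compile.builders import construct_nominal_fgraph

x = at.scalar("x")
s = pytensor.shared(np.float64(1.0), name="s")
out = x + s

# `s` is collected into `shared_inputs` and replaced by a dummy input.
fgraph, shared_inputs, _, _ = construct_nominal_fgraph([x], [out])
assert shared_inputs == [s]
assert len(fgraph.inputs) == 2  # `x` plus the dummy standing in for `s`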

tests/scan/test_basic.py
-4

@@ -586,10 +586,6 @@ def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W):
         assert np.allclose(pytensor_values, v_out)
 
     def test_oinp_iinp_iout_oout_mappings(self):
-        """
-        Test the mapping produces by
-        ScanOp.get_oinp_iinp_iout_oout_mappings()
-        """
 
         rng = RandomStream(123)
 
0 commit comments
