@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
92
92
dict [Variable , Variable ],
93
93
]:
94
94
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
95
- dummy_inputs = []
96
- for n , inp in enumerate (inputs ):
97
- if (
98
- not isinstance (inp , Variable )
99
- or isinstance (inp , Constant )
100
- or isinstance (inp , SharedVariable )
101
- ):
102
- raise TypeError (
103
- f"Inputs and outputs must be non-Constant/shared Variable instances; got { inp } "
104
- )
105
-
106
- dummy_inputs .append (inp .type ())
95
+ implicit_shared_inputs = []
107
96
108
- dummy_shared_inputs = []
109
- shared_inputs = []
97
+ dummy_inputs = [inp . type () for inp in inputs ]
98
+ dummy_implicit_shared_inputs = []
110
99
for var in graph_inputs (outputs , inputs ):
100
+ if var in inputs :
101
+ continue
111
102
if isinstance (var , SharedVariable ):
112
- # To correctly support shared variables the inner- graph should
113
- # not see them; otherwise, there will be problems with
114
- # gradients.
115
- # That's why we collect the shared variables and replace them
116
- # with dummies.
117
- shared_inputs . append ( var )
118
- dummy_shared_inputs . append ( var . type ())
119
- elif var not in inputs and not isinstance ( var , Constant ):
120
- raise MissingInputError ( f"OpFromGraph is missing an input: { var } " )
121
-
122
- replacements = dict ( zip ( inputs + shared_inputs , dummy_inputs + dummy_shared_inputs ) )
103
+ # We allow shared inputs to be added automatically to the graph
104
+ implicit_shared_inputs . append ( var )
105
+ dummy_implicit_shared_inputs . append ( var . type ())
106
+ elif not isinstance ( var , Constant ):
107
+ raise MissingInputError ( f"NominalGraph is missing an input: { var } " )
108
+
109
+ replacements = dict (
110
+ zip (
111
+ inputs + implicit_shared_inputs , dummy_inputs + dummy_implicit_shared_inputs
112
+ )
113
+ )
123
114
124
115
new = rebuild_collect_shared (
125
116
cast (Sequence [Variable ], outputs ),
126
- inputs = inputs + shared_inputs ,
117
+ inputs = inputs + implicit_shared_inputs ,
127
118
replace = replacements ,
128
119
copy_inputs_over = False ,
129
120
)
@@ -133,7 +124,7 @@ def construct_nominal_fgraph(
133
124
(clone_d , update_d , update_expr , new_shared_inputs ),
134
125
) = new
135
126
136
- assert len (local_inputs ) == len (inputs ) + len (shared_inputs )
127
+ assert len (local_inputs ) == len (inputs ) + len (implicit_shared_inputs )
137
128
assert len (local_outputs ) == len (outputs )
138
129
assert not update_d
139
130
assert not update_expr
@@ -155,7 +146,7 @@ def construct_nominal_fgraph(
155
146
fgraph .clients .pop (inp , None )
156
147
fgraph .add_input (nom_inp )
157
148
158
- return fgraph , shared_inputs , update_d , update_expr
149
+ return fgraph , implicit_shared_inputs , update_d , update_expr
159
150
160
151
161
152
class OpFromGraph (Op , HasInnerGraph ):
@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
177
168
- grad() make it support DisconnectedType and the new interface
178
169
- add support for NullType and DisconnectedType when R_op supports them
179
170
- check how it works with updates.
180
- - add test with constant as input or inside the inner graph.
181
- - Add support for the GPU? Probably just need an opt to remove transfer
182
171
- Add support to pickle this Op.
183
172
- Add support/test with random generator
184
173
- Add optimization to removing unused inputs/outputs
@@ -310,11 +299,13 @@ def __init__(
310
299
self ,
311
300
inputs : list [Variable ],
312
301
outputs : list [Variable ],
302
+ * ,
313
303
inline : bool = False ,
314
304
lop_overrides : str = "default" ,
315
305
grad_overrides : str = "default" ,
316
306
rop_overrides : str = "default" ,
317
307
connection_pattern : Optional [list [list [bool ]]] = None ,
308
+ strict : bool = False ,
318
309
name : Optional [str ] = None ,
319
310
** kwargs ,
320
311
):
@@ -399,6 +390,8 @@ def __init__(
399
390
must be equal to number of outputs. connection_pattern If not
400
391
``None``, this will be used as the connection_pattern for this
401
392
:class:`Op`.
393
+ strict: bool, default False
394
+ Raise if SharedVariables needed to compute the graph are not provided as explicit inputs.
402
395
name
403
396
A name for debugging purposes.
404
397
kwargs
@@ -424,6 +417,12 @@ def __init__(
424
417
inputs , outputs
425
418
)
426
419
420
+ if strict and self .shared_inputs :
421
+ raise ValueError (
422
+ "All shared variables must be provided as inputs under strict=True. "
423
+ f"The following variables were missing { self .shared_inputs } "
424
+ )
425
+
427
426
self .kwargs = kwargs
428
427
self .input_types = [inp .type for inp in inputs ]
429
428
self .output_types = [out .type for out in outputs ]
0 commit comments