19
19
clone_replace ,
20
20
graph_inputs ,
21
21
io_connection_pattern ,
22
- replace_nominals_with_dummies ,
23
22
)
24
23
from pytensor .graph .fg import FunctionGraph
25
24
from pytensor .graph .null_type import NullType
@@ -333,52 +332,51 @@ def __init__(
333
332
if not (isinstance (inputs , list ) and isinstance (outputs , list )):
334
333
raise TypeError ("Inputs and outputs must be lists" )
335
334
336
- for i in inputs + outputs :
337
- if not isinstance (i , Variable ):
335
+ for out in outputs :
336
+ if not isinstance (out , Variable ):
338
337
raise TypeError (
339
- f"Inputs and outputs must be Variable instances; got { i } "
338
+ f"Inputs and outputs must be Variable instances; got { out } "
340
339
)
341
- if i in inputs :
342
- if isinstance (i , Constant ):
343
- raise TypeError (f"Constants not allowed as inputs; { i } " )
344
- if isinstance (i , SharedVariable ):
345
- raise TypeError (f"SharedVariables not allowed as inputs; { i } " )
340
+
341
+ dummy_inputs = []
342
+ for n , inp in enumerate (inputs ):
343
+ if (
344
+ not isinstance (inp , Variable )
345
+ or isinstance (inp , Constant )
346
+ or isinstance (inp , SharedVariable )
347
+ ):
348
+ raise TypeError (
349
+ f"Inputs and outputs must be non-Constant/shared Variable instances; got { inp } "
350
+ )
351
+
352
+ dummy_inputs .append (inp .type ())
346
353
347
354
if "updates" in kwargs or "givens" in kwargs :
348
355
raise NotImplementedError ("Updates and givens are not supported" )
349
356
350
357
self .is_inline = inline
351
358
359
+ dummy_shared_inputs = []
352
360
self .shared_inputs = []
353
- inner_graph_inputs = graph_inputs (outputs , inputs )
354
- for var in inner_graph_inputs :
361
+ for var in graph_inputs (outputs , inputs ):
355
362
if isinstance (var , SharedVariable ):
356
363
# To correctly support shared variables the inner-graph should
357
364
# not see them; otherwise, there will be problems with
358
365
# gradients.
359
366
# That's why we collect the shared variables and replace them
360
367
# with dummies.
361
368
self .shared_inputs .append (var )
369
+ dummy_shared_inputs .append (var .type ())
362
370
elif var not in inputs and not isinstance (var , Constant ):
363
371
raise MissingInputError (f"OpFromGraph is missing an input: { var } " )
364
372
365
- inputs , outputs = replace_nominals_with_dummies (inputs , outputs )
366
-
367
- # The inputs should be `NominalVariable`s, so that graphs can be merged
368
- replacements = {}
369
- for n , v in enumerate (inputs ):
370
- replacements [v ] = NominalVariable (n , v .type )
371
-
372
- shared_vars = [
373
- NominalVariable (n , var .type )
374
- for n , var in enumerate (self .shared_inputs , start = len (inputs ) + 1 )
375
- ]
376
-
377
- replacements .update (dict (zip (self .shared_inputs , shared_vars )))
373
+ replacements = dict (
374
+ zip (inputs + self .shared_inputs , dummy_inputs + dummy_shared_inputs )
375
+ )
378
376
379
377
new = rebuild_collect_shared (
380
378
cast (Sequence [Variable ], outputs ),
381
- inputs = inputs + shared_vars ,
379
+ inputs = inputs + self . shared_inputs ,
382
380
replace = replacements ,
383
381
copy_inputs_over = False ,
384
382
)
@@ -395,6 +393,21 @@ def __init__(
395
393
assert not shared_inputs
396
394
397
395
self .fgraph = FunctionGraph (local_inputs , local_outputs , clone = False )
396
+
397
+ # The inputs need to be `NominalVariable`s so that we can merge
398
+ # inner-graphs
399
+ nominal_local_inputs = tuple (
400
+ NominalVariable (n , var .type ) for n , var in enumerate (local_inputs )
401
+ )
402
+
403
+ self .fgraph .replace_all (zip (local_inputs , nominal_local_inputs ))
404
+
405
+ for i , inp in enumerate (self .fgraph .inputs ):
406
+ nom_inp = nominal_local_inputs [i ]
407
+ self .fgraph .inputs [i ] = nom_inp
408
+ self .fgraph .clients .pop (inp , None )
409
+ self .fgraph .add_input (nom_inp )
410
+
398
411
self .kwargs = kwargs
399
412
self .input_types = [inp .type for inp in inputs ]
400
413
self .output_types = [out .type for out in outputs ]
@@ -417,6 +430,7 @@ def __init__(
417
430
else :
418
431
self .set_lop_overrides ("default" )
419
432
self ._lop_type = "lop"
433
+
420
434
self .set_rop_overrides (rop_overrides )
421
435
422
436
self ._connection_pattern = connection_pattern
0 commit comments