|
2 | 2 | from collections import OrderedDict
|
3 | 3 | from copy import copy
|
4 | 4 | from functools import partial
|
5 |
| -from typing import List, Optional, Sequence, cast |
| 5 | +from typing import Dict, List, Optional, Sequence, Tuple, cast |
6 | 6 |
|
7 | 7 | import pytensor.tensor as at
|
8 | 8 | from pytensor import function
|
@@ -81,6 +81,81 @@ def local_traverse(out):
|
81 | 81 | return ret
|
82 | 82 |
|
83 | 83 |
|
| 84 | +def construct_nominal_fgraph( |
| 85 | + inputs: Sequence[Variable], outputs: Sequence[Variable] |
| 86 | +) -> Tuple[ |
| 87 | + FunctionGraph, |
| 88 | + Sequence[Variable], |
| 89 | + Dict[Variable, Variable], |
| 90 | + Dict[Variable, Variable], |
| 91 | +]: |
| 92 | + """Construct an inner-`FunctionGraph` with ordered nominal inputs.""" |
| 93 | + dummy_inputs = [] |
| 94 | + for n, inp in enumerate(inputs): |
| 95 | + if ( |
| 96 | + not isinstance(inp, Variable) |
| 97 | + or isinstance(inp, Constant) |
| 98 | + or isinstance(inp, SharedVariable) |
| 99 | + ): |
| 100 | + raise TypeError( |
| 101 | + f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}" |
| 102 | + ) |
| 103 | + |
| 104 | + dummy_inputs.append(inp.type()) |
| 105 | + |
| 106 | + dummy_shared_inputs = [] |
| 107 | + shared_inputs = [] |
| 108 | + for var in graph_inputs(outputs, inputs): |
| 109 | + if isinstance(var, SharedVariable): |
| 110 | + # To correctly support shared variables the inner-graph should |
| 111 | + # not see them; otherwise, there will be problems with |
| 112 | + # gradients. |
| 113 | + # That's why we collect the shared variables and replace them |
| 114 | + # with dummies. |
| 115 | + shared_inputs.append(var) |
| 116 | + dummy_shared_inputs.append(var.type()) |
| 117 | + elif var not in inputs and not isinstance(var, Constant): |
| 118 | + raise MissingInputError(f"OpFromGraph is missing an input: {var}") |
| 119 | + |
| 120 | + replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs)) |
| 121 | + |
| 122 | + new = rebuild_collect_shared( |
| 123 | + cast(Sequence[Variable], outputs), |
| 124 | + inputs=inputs + shared_inputs, |
| 125 | + replace=replacements, |
| 126 | + copy_inputs_over=False, |
| 127 | + ) |
| 128 | + ( |
| 129 | + local_inputs, |
| 130 | + local_outputs, |
| 131 | + (clone_d, update_d, update_expr, new_shared_inputs), |
| 132 | + ) = new |
| 133 | + |
| 134 | + assert len(local_inputs) == len(inputs) + len(shared_inputs) |
| 135 | + assert len(local_outputs) == len(outputs) |
| 136 | + assert not update_d |
| 137 | + assert not update_expr |
| 138 | + assert not new_shared_inputs |
| 139 | + |
| 140 | + fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) |
| 141 | + |
| 142 | + # The inputs need to be `NominalVariable`s so that we can merge |
| 143 | + # inner-graphs |
| 144 | + nominal_local_inputs = tuple( |
| 145 | + NominalVariable(n, var.type) for n, var in enumerate(local_inputs) |
| 146 | + ) |
| 147 | + |
| 148 | + fgraph.replace_all(zip(local_inputs, nominal_local_inputs)) |
| 149 | + |
| 150 | + for i, inp in enumerate(fgraph.inputs): |
| 151 | + nom_inp = nominal_local_inputs[i] |
| 152 | + fgraph.inputs[i] = nom_inp |
| 153 | + fgraph.clients.pop(inp, None) |
| 154 | + fgraph.add_input(nom_inp) |
| 155 | + |
| 156 | + return fgraph, shared_inputs, update_d, update_expr |
| 157 | + |
| 158 | + |
84 | 159 | class OpFromGraph(Op, HasInnerGraph):
|
85 | 160 | r"""
|
86 | 161 | This creates an `Op` from inputs and outputs lists of variables.
|
@@ -338,76 +413,15 @@ def __init__(
|
338 | 413 | f"Inputs and outputs must be Variable instances; got {out}"
|
339 | 414 | )
|
340 | 415 |
|
341 |
| - dummy_inputs = [] |
342 |
| - for n, inp in enumerate(inputs): |
343 |
| - if ( |
344 |
| - not isinstance(inp, Variable) |
345 |
| - or isinstance(inp, Constant) |
346 |
| - or isinstance(inp, SharedVariable) |
347 |
| - ): |
348 |
| - raise TypeError( |
349 |
| - f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}" |
350 |
| - ) |
351 |
| - |
352 |
| - dummy_inputs.append(inp.type()) |
353 |
| - |
354 | 416 | if "updates" in kwargs or "givens" in kwargs:
|
355 | 417 | raise NotImplementedError("Updates and givens are not supported")
|
356 | 418 |
|
357 | 419 | self.is_inline = inline
|
358 | 420 |
|
359 |
| - dummy_shared_inputs = [] |
360 |
| - self.shared_inputs = [] |
361 |
| - for var in graph_inputs(outputs, inputs): |
362 |
| - if isinstance(var, SharedVariable): |
363 |
| - # To correctly support shared variables the inner-graph should |
364 |
| - # not see them; otherwise, there will be problems with |
365 |
| - # gradients. |
366 |
| - # That's why we collect the shared variables and replace them |
367 |
| - # with dummies. |
368 |
| - self.shared_inputs.append(var) |
369 |
| - dummy_shared_inputs.append(var.type()) |
370 |
| - elif var not in inputs and not isinstance(var, Constant): |
371 |
| - raise MissingInputError(f"OpFromGraph is missing an input: {var}") |
372 |
| - |
373 |
| - replacements = dict( |
374 |
| - zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs) |
| 421 | + self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( |
| 422 | + inputs, outputs |
375 | 423 | )
|
376 | 424 |
|
377 |
| - new = rebuild_collect_shared( |
378 |
| - cast(Sequence[Variable], outputs), |
379 |
| - inputs=inputs + self.shared_inputs, |
380 |
| - replace=replacements, |
381 |
| - copy_inputs_over=False, |
382 |
| - ) |
383 |
| - ( |
384 |
| - local_inputs, |
385 |
| - local_outputs, |
386 |
| - (clone_d, update_d, update_expr, shared_inputs), |
387 |
| - ) = new |
388 |
| - |
389 |
| - assert len(local_inputs) == len(inputs) + len(self.shared_inputs) |
390 |
| - assert len(local_outputs) == len(outputs) |
391 |
| - assert not update_d |
392 |
| - assert not update_expr |
393 |
| - assert not shared_inputs |
394 |
| - |
395 |
| - self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) |
396 |
| - |
397 |
| - # The inputs need to be `NominalVariable`s so that we can merge |
398 |
| - # inner-graphs |
399 |
| - nominal_local_inputs = tuple( |
400 |
| - NominalVariable(n, var.type) for n, var in enumerate(local_inputs) |
401 |
| - ) |
402 |
| - |
403 |
| - self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs)) |
404 |
| - |
405 |
| - for i, inp in enumerate(self.fgraph.inputs): |
406 |
| - nom_inp = nominal_local_inputs[i] |
407 |
| - self.fgraph.inputs[i] = nom_inp |
408 |
| - self.fgraph.clients.pop(inp, None) |
409 |
| - self.fgraph.add_input(nom_inp) |
410 |
| - |
411 | 425 | self.kwargs = kwargs
|
412 | 426 | self.input_types = [inp.type for inp in inputs]
|
413 | 427 | self.output_types = [out.type for out in outputs]
|
|
0 commit comments