Skip to content

Commit 762c4c5

Browse files
brandonwillardricardoV94
authored andcommitted
Remove redundant cloning when swapping nominal variables in OpFromGraph
1 parent 18cd693 commit 762c4c5

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

pytensor/compile/builders.py

+39-25
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
clone_replace,
2020
graph_inputs,
2121
io_connection_pattern,
22-
replace_nominals_with_dummies,
2322
)
2423
from pytensor.graph.fg import FunctionGraph
2524
from pytensor.graph.null_type import NullType
@@ -333,52 +332,51 @@ def __init__(
333332
if not (isinstance(inputs, list) and isinstance(outputs, list)):
334333
raise TypeError("Inputs and outputs must be lists")
335334

336-
for i in inputs + outputs:
337-
if not isinstance(i, Variable):
335+
for out in outputs:
336+
if not isinstance(out, Variable):
338337
raise TypeError(
339-
f"Inputs and outputs must be Variable instances; got {i}"
338+
f"Inputs and outputs must be Variable instances; got {out}"
340339
)
341-
if i in inputs:
342-
if isinstance(i, Constant):
343-
raise TypeError(f"Constants not allowed as inputs; {i}")
344-
if isinstance(i, SharedVariable):
345-
raise TypeError(f"SharedVariables not allowed as inputs; {i}")
340+
341+
dummy_inputs = []
342+
for n, inp in enumerate(inputs):
343+
if (
344+
not isinstance(inp, Variable)
345+
or isinstance(inp, Constant)
346+
or isinstance(inp, SharedVariable)
347+
):
348+
raise TypeError(
349+
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
350+
)
351+
352+
dummy_inputs.append(inp.type())
346353

347354
if "updates" in kwargs or "givens" in kwargs:
348355
raise NotImplementedError("Updates and givens are not supported")
349356

350357
self.is_inline = inline
351358

359+
dummy_shared_inputs = []
352360
self.shared_inputs = []
353-
inner_graph_inputs = graph_inputs(outputs, inputs)
354-
for var in inner_graph_inputs:
361+
for var in graph_inputs(outputs, inputs):
355362
if isinstance(var, SharedVariable):
356363
# To correctly support shared variables the inner-graph should
357364
# not see them; otherwise, there will be problems with
358365
# gradients.
359366
# That's why we collect the shared variables and replace them
360367
# with dummies.
361368
self.shared_inputs.append(var)
369+
dummy_shared_inputs.append(var.type())
362370
elif var not in inputs and not isinstance(var, Constant):
363371
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
364372

365-
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
366-
367-
# The inputs should be `NominalVariable`s, so that graphs can be merged
368-
replacements = {}
369-
for n, v in enumerate(inputs):
370-
replacements[v] = NominalVariable(n, v.type)
371-
372-
shared_vars = [
373-
NominalVariable(n, var.type)
374-
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
375-
]
376-
377-
replacements.update(dict(zip(self.shared_inputs, shared_vars)))
373+
replacements = dict(
374+
zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
375+
)
378376

379377
new = rebuild_collect_shared(
380378
cast(Sequence[Variable], outputs),
381-
inputs=inputs + shared_vars,
379+
inputs=inputs + self.shared_inputs,
382380
replace=replacements,
383381
copy_inputs_over=False,
384382
)
@@ -395,6 +393,21 @@ def __init__(
395393
assert not shared_inputs
396394

397395
self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
396+
397+
# The inputs need to be `NominalVariable`s so that we can merge
398+
# inner-graphs
399+
nominal_local_inputs = tuple(
400+
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
401+
)
402+
403+
self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
404+
405+
for i, inp in enumerate(self.fgraph.inputs):
406+
nom_inp = nominal_local_inputs[i]
407+
self.fgraph.inputs[i] = nom_inp
408+
self.fgraph.clients.pop(inp, None)
409+
self.fgraph.add_input(nom_inp)
410+
398411
self.kwargs = kwargs
399412
self.input_types = [inp.type for inp in inputs]
400413
self.output_types = [out.type for out in outputs]
@@ -417,6 +430,7 @@ def __init__(
417430
else:
418431
self.set_lop_overrides("default")
419432
self._lop_type = "lop"
433+
420434
self.set_rop_overrides(rop_overrides)
421435

422436
self._connection_pattern = connection_pattern

0 commit comments

Comments
 (0)