Skip to content

Commit 86bc1d2

Browse files
Add name kwarg to Op.__call__ (#693)
1 parent 14651fb commit 86bc1d2

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

pytensor/graph/op.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ def make_node(self, *inputs: Variable) -> Apply:
246246
)
247247
return Apply(self, inputs, [o() for o in self.otypes])
248248

249-
def __call__(self, *inputs: Any, **kwargs) -> Variable | list[Variable]:
249+
def __call__(
250+
self, *inputs: Any, name=None, return_list=False, **kwargs
251+
) -> Variable | list[Variable]:
250252
r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
251253
252254
This method is just a wrapper around :meth:`Op.make_node`.
@@ -288,8 +290,15 @@ def __call__(self, *inputs: Any, **kwargs) -> Variable | list[Variable]:
288290
the :attr:`Op.default_output` property.
289291
290292
"""
291-
return_list = kwargs.pop("return_list", False)
292293
node = self.make_node(*inputs, **kwargs)
294+
if name is not None:
295+
if len(node.outputs) == 1:
296+
node.outputs[0].name = name
297+
elif self.default_output is not None:
298+
node.outputs[self.default_output].name = name
299+
else:
300+
for i, n in enumerate(node.outputs):
301+
n.name = f"{name}_{i}"
293302

294303
if config.compute_test_value != "off":
295304
compute_test_value(node)

tests/graph/test_op.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,46 @@ def perform(self, *_):
232232

233233
x = pt.TensorType(dtype="float64", shape=(1,))("x")
234234
assert SomeOp()(x).type == pt.dvector
235+
236+
237+
@pytest.mark.parametrize("multi_output", [True, False])
238+
def test_call_name(multi_output):
239+
def dummy_variable(name):
240+
return Variable(MyType(thingy=None), None, None, name=name)
241+
242+
x = dummy_variable("x")
243+
244+
class TestCallOp(Op):
245+
def __init__(self, default_output, multi_output):
246+
super().__init__()
247+
self.default_output = default_output
248+
self.multi_output = multi_output
249+
250+
def make_node(self, input):
251+
inputs = [input]
252+
if self.multi_output:
253+
outputs = [input.type(), input.type()]
254+
else:
255+
outputs = [input.type()]
256+
return Apply(self, inputs, outputs)
257+
258+
def perform(self, node, inputs, outputs):
259+
raise NotImplementedError()
260+
261+
if multi_output:
262+
multi_op = TestCallOp(default_output=None, multi_output=multi_output)
263+
res = multi_op(x, name="test_name")
264+
for i, r in enumerate(res):
265+
assert r.name == f"test_name_{i}"
266+
267+
multi_op = TestCallOp(default_output=1, multi_output=multi_output)
268+
result = multi_op(x, name="test_name")
269+
assert result.owner.outputs[0].name is None
270+
assert result.name == "test_name"
271+
else:
272+
single_op = TestCallOp(default_output=None, multi_output=multi_output)
273+
res_single = single_op(x, name="test_name")
274+
assert res_single.name == "test_name"
275+
276+
res_nameless = single_op(x)
277+
assert res_nameless.name is None

0 commit comments

Comments
 (0)