Commit 25e70e8

Render inner-graphs of Composite Ops in debugprint
1 parent 132b93f commit 25e70e8
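
As a rough illustration of what the change buys (a sketch, not output copied from this commit): once an elemwise chain is fused into an Elemwise{Composite} node, debugprint now expands the Composite's scalar inner graph in an "Inner graphs:" section instead of packing the whole expression into the Op's name.

    # Illustrative sketch; the variable names and the exact rendering are assumptions.
    import pytensor
    import pytensor.tensor as at
    from pytensor.printing import debugprint

    x = at.vector("x")
    y = at.vector("y")
    # FAST_RUN fuses the elemwise chain into a single Elemwise{Composite} node.
    f = pytensor.function([x, y], at.exp(x + y) * 2, mode="FAST_RUN")

    # With this commit the fused node prints as "Elemwise{Composite}" and is
    # followed by an "Inner graphs:" section expanding the scalar graph it wraps.
    print(debugprint(f, file="str"))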

7 files changed, +65 -64 lines

pytensor/printing.py (+17 -7)

@@ -312,7 +312,11 @@ def debugprint(
     ):
 
         if hasattr(var.owner, "op"):
-            if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars:
+            if (
+                isinstance(var.owner.op, HasInnerGraph)
+                or hasattr(var.owner.op, "scalar_op")
+                and isinstance(var.owner.op.scalar_op, HasInnerGraph)
+            ) and var not in inner_graph_vars:
                inner_graph_vars.append(var)
                if print_op_info:
                    op_information.update(op_debug_information(var.owner.op, var.owner))
@@ -355,8 +359,12 @@ def debugprint(
             inner_inputs = inner_fn.maker.fgraph.inputs
             inner_outputs = inner_fn.maker.fgraph.outputs
         else:
-            inner_inputs = ig_var.owner.op.inner_inputs
-            inner_outputs = ig_var.owner.op.inner_outputs
+            if hasattr(ig_var.owner.op, "scalar_op"):
+                inner_inputs = ig_var.owner.op.scalar_op.inner_inputs
+                inner_outputs = ig_var.owner.op.scalar_op.inner_outputs
+            else:
+                inner_inputs = ig_var.owner.op.inner_inputs
+                inner_outputs = ig_var.owner.op.inner_outputs
 
         outer_inputs = ig_var.owner.inputs
 
@@ -422,8 +430,9 @@ def debugprint(
 
         if (
             isinstance(getattr(out.owner, "op", None), HasInnerGraph)
-            and out not in inner_graph_vars
-        ):
+            or hasattr(getattr(out.owner, "op", None), "scalar_op")
+            and isinstance(out.owner.op.scalar_op, HasInnerGraph)
+        ) and out not in inner_graph_vars:
             inner_graph_vars.append(out)
 
         _debugprint(
@@ -664,8 +673,9 @@ def get_id_str(
             if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
                 if (
                     isinstance(in_var.owner.op, HasInnerGraph)
-                    and in_var not in inner_graph_ops
-                ):
+                    or hasattr(in_var.owner.op, "scalar_op")
+                    and isinstance(in_var.owner.op.scalar_op, HasInnerGraph)
+                ) and in_var not in inner_graph_ops:
                     inner_graph_ops.append(in_var)
 
             _debugprint(
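
All four hunks above add the same predicate. Restated as a standalone helper (a sketch using the names from the diff; it assumes HasInnerGraph is importable from pytensor.graph.op, as this version of printing.py does): an Op is expanded under "Inner graphs:" if it owns an inner graph itself (Scan, OpFromGraph, scalar Composite) or if it is an Elemwise whose scalar_op does. Since `and` binds tighter than `or`, the inline condition reads as `isinstance(op, HasInnerGraph) or (hasattr(op, "scalar_op") and isinstance(op.scalar_op, HasInnerGraph))`.

    # Sketch of the shared check, factored out for readability; not part of the commit.
    from pytensor.graph.op import HasInnerGraph

    def exposes_inner_graph(op) -> bool:
        # True for Ops that debugprint should expand under "Inner graphs:".
        return isinstance(op, HasInnerGraph) or (
            hasattr(op, "scalar_op") and isinstance(op.scalar_op, HasInnerGraph)
        )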

pytensor/scalar/basic.py (+2 -25)

@@ -4000,7 +4000,8 @@ class Composite(ScalarOp, HasInnerGraph):
 
     init_param: Tuple[str, ...] = ("inputs", "outputs")
 
-    def __init__(self, inputs, outputs):
+    def __init__(self, inputs, outputs, name="Composite"):
+        self.name = name
         # We need to clone the graph as sometimes its nodes already
         # contain a reference to an fgraph. As we want the Composite
         # to be pickable, we can't have reference to fgraph.
@@ -4106,30 +4107,6 @@ def _perform(*inputs, outputs=[[None]]):
         self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
         return self._py_perform_fn
 
-    @property
-    def name(self):
-        if hasattr(self, "_name"):
-            return self._name
-
-        # TODO FIXME: Just implement pretty printing for the `Op`; don't do
-        # this redundant, outside work in the `Op` itself.
-        for i, r in enumerate(self.fgraph.inputs):
-            r.name = f"i{int(i)}"
-        for i, r in enumerate(self.fgraph.outputs):
-            r.name = f"o{int(i)}"
-        io = set(self.fgraph.inputs + self.fgraph.outputs)
-        for i, r in enumerate(self.fgraph.variables):
-            if r not in io and len(self.fgraph.clients[r]) > 1:
-                r.name = f"t{int(i)}"
-        outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
-        rval = f"Composite{{{outputs_str}}}"
-        self._name = rval
-        return self._name
-
-    @name.setter
-    def name(self, name):
-        self._name = name
-
     @property
     def fgraph(self):
         if hasattr(self, "_fgraph"):
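
The practical effect (a hypothetical snippet, not taken from the test suite): Composite.name becomes a plain attribute set in __init__, defaulting to "Composite", instead of a property that pretty-printed every output expression into the Op's name.

    # Hypothetical example of the new naming behaviour.
    from pytensor.scalar.basic import Composite, float64

    x = float64("x")
    y = float64("y")
    comp = Composite([x, y], [x + y])

    # Before this commit the name embedded the expression, e.g. "Composite{(i0 + i1)}";
    # now it is just the name passed to __init__ (default "Composite"), and the
    # expression is inspected through debugprint's "Inner graphs:" section instead.
    print(comp.name)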

tests/scalar/test_basic.py (+1 -6)

@@ -183,12 +183,7 @@ def test_composite_printing(self):
         make_function(DualLinker().accept(g))
 
         assert str(g) == (
-            "FunctionGraph(*1 -> Composite{((i0 + i1) + i2),"
-            " (i0 + (i1 * i2)), (i0 / i1), "
-            "(i0 // 5), "
-            "(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),"
-            " (i0 % 3)}(x, y, z), "
-            "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
+            "FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
         )
 
     def test_non_scalar_error(self):

tests/scan/test_printing.py (+27 -18)

@@ -604,31 +604,40 @@ def no_shared_fn(n, x_tm1, M):
     out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
 
     expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
- |TensorConstant{20000} [id B] (n_steps)
- |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
- |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
- | |AllocEmpty{dtype='int64'} [id E] 0
- | | |TensorConstant{20000} [id B]
- | |TensorConstant{(1,) of 0} [id F]
- | |ScalarConstant{1} [id G]
- |<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
+ |TensorConstant{20000} [id B] (n_steps)
+ |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
+ |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
+ | |AllocEmpty{dtype='int64'} [id E] 0
+ | | |TensorConstant{20000} [id B]
+ | |TensorConstant{(1,) of 0} [id F]
+ | |ScalarConstant{1} [id G]
+ |<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
 
 Inner graphs:
 
 forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
->Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
-> |TensorConstant{0} [id J]
-> |Subtensor{int64, int64, uint8} [id K]
-> | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
-> | |ScalarFromTensor [id M]
-> | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
-> | |ScalarFromTensor [id O]
-> | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
-> | |ScalarConstant{0} [id Q]
-> |TensorConstant{1} [id R]
+>Elemwise{Composite} [id I] (inner_out_sit_sot-0)
+> |TensorConstant{0} [id J]
+> |Subtensor{int64, int64, uint8} [id K]
+> | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
+> | |ScalarFromTensor [id M]
+> | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
+> | |ScalarFromTensor [id O]
+> | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
+> | |ScalarConstant{0} [id Q]
+> |TensorConstant{1} [id R]
+
+Elemwise{Composite} [id I]
+>Switch [id S]
+> |LT [id T]
+> | |<int64> [id U]
+> | |<float64> [id V]
+> |<int64> [id W]
+> |<int64> [id U]
 """
 
     output_str = debugprint(out, file="str", print_op_info=True)
+    print(output_str)
     lines = output_str.split("\n")
 
     for truth, out in zip(expected_output.split("\n"), lines):

tests/tensor/rewriting/test_basic.py (+4 -4)

@@ -16,7 +16,7 @@
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import rewrite_graph
-from pytensor.printing import pprint
+from pytensor.printing import debugprint, pprint
 from pytensor.raise_op import Assert, CheckAndRaise
 from pytensor.tensor.basic import (
     Alloc,
@@ -1105,7 +1105,7 @@ def test_elemwise_float_ops(self, op):
         s2 = at.switch(c, x, y)
 
         g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
-        assert str(g).count("Switch") == 1
+        assert debugprint(g, file="str").count("Switch") == 1
 
     @pytest.mark.parametrize(
         "op",
@@ -1122,7 +1122,7 @@ def test_elemwise_int_ops(self, op):
         s1 = at.switch(c, a, b)
         s2 = at.switch(c, x, y)
         g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
-        assert str(g).count("Switch") == 1
+        assert debugprint(g, file="str").count("Switch") == 1
 
     @pytest.mark.parametrize("op", [add, mul])
     def test_elemwise_multi_inputs(self, op):
@@ -1134,7 +1134,7 @@ def test_elemwise_multi_inputs(self, op):
         u, v = matrices("uv")
         s3 = at.switch(c, u, v)
         g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
-        assert str(g).count("Switch") == 1
+        assert debugprint(g, file="str").count("Switch") == 1
 
 
 class TestLocalOptAlloc:
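
The assertions above move from str(g) to debugprint(g, file="str") because the fused Elemwise{Composite}'s name no longer spells out the inner Switch; only the rendered inner graph does. A rough repro of the idea (illustrative; the variable names and whether fusion fires for this exact graph are assumptions, not taken from the tests):

    import pytensor
    import pytensor.tensor as at
    from pytensor.printing import debugprint

    c, x, y = at.matrices("c", "x", "y")
    # Under FAST_RUN, fusion can fold the switch and the following add into one
    # Elemwise{Composite}, whose name no longer mentions Switch.
    f = pytensor.function([c, x, y], at.switch(c, x, y) + 1, mode="FAST_RUN")

    # The Switch remains countable in the debugprint output, because the
    # Composite's inner graph is now rendered under "Inner graphs:".
    print(debugprint(f, file="str").count("Switch"))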

tests/tensor/rewriting/test_math.py (+4 -3)

@@ -28,6 +28,7 @@
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
 from pytensor.misc.safe_asarray import _asarray
+from pytensor.printing import debugprint
 from pytensor.tensor import inplace
 from pytensor.tensor.basic import Alloc, join, switch
 from pytensor.tensor.blas import Dot22, Gemv
@@ -2416,7 +2417,7 @@ def test_elemwise(self):
             at_pow,
         ):
             g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
-            assert str(g).count("Switch") == 1
+            assert debugprint(g, file="str").count("Switch") == 1
         # integer Ops
         mats = imatrices("cabxy")
         c, a, b, x, y = mats
@@ -2428,13 +2429,13 @@ def test_elemwise(self):
             bitwise_xor,
         ):
             g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
-            assert str(g).count("Switch") == 1
+            assert debugprint(g, file="str").count("Switch") == 1
         # add/mul with more than two inputs
         u, v = matrices("uv")
         s3 = at.switch(c, u, v)
         for op in (add, mul):
             g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
-            assert str(g).count("Switch") == 1
+            assert debugprint(g, file="str").count("Switch") == 1
 
 
 class TestLocalSumProd:

tests/test_printing.py (+10 -1)

@@ -273,7 +273,7 @@ def test_debugprint():
     s = s.getvalue()
     exp_res = dedent(
         r"""
-Elemwise{Composite{(i2 + (i0 - i1))}} 4
+Elemwise{Composite} 4
 |InplaceDimShuffle{x,0} v={0: [0]} 3
 | |CGemv{inplace} d={0: [0]} 2
 | |AllocEmpty{dtype='float64'} 1
@@ -285,6 +285,15 @@ def test_debugprint():
 | |TensorConstant{0.0}
 |D
 |A
+
+Inner graphs:
+
+Elemwise{Composite}
+>add
+> |<float64>
+> |sub
+> |<float64>
+> |<float64>
     """
     ).lstrip()
 