Skip to content

Commit 947a10b

Browse files
committed
Improve debug_print formatting
* Use better ASCII symbols * Show continuation hint for already seen nodes * Don't repeat entries for multiple output inner graphs
1 parent dc5fc6d commit 947a10b

File tree

4 files changed

+537
-488
lines changed

4 files changed

+537
-488
lines changed

pytensor/printing.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def debugprint(
291291
for var in inputs_to_print:
292292
_debugprint(
293293
var,
294-
prefix="-",
294+
prefix="",
295295
depth=depth,
296296
done=done,
297297
print_type=print_type,
@@ -342,11 +342,17 @@ def debugprint(
342342

343343
if len(inner_graph_vars) > 0:
344344
print("", file=_file)
345-
new_prefix = " >"
346-
new_prefix_child = " >"
345+
prefix = ""
346+
new_prefix = prefix + " ← "
347+
new_prefix_child = prefix + " "
347348
print("Inner graphs:", file=_file)
348349

350+
printed_inner_graphs_nodes = set()
349351
for ig_var in inner_graph_vars:
352+
if ig_var.owner in printed_inner_graphs_nodes:
353+
continue
354+
else:
355+
printed_inner_graphs_nodes.add(ig_var.owner)
350356
# This is a work-around to maintain backward compatibility
351357
# (e.g. to only print inner graphs that have been compiled through
352358
# a call to `Op.prepare_node`)
@@ -385,6 +391,7 @@ def debugprint(
385391

386392
_debugprint(
387393
ig_var,
394+
prefix=prefix,
388395
depth=depth,
389396
done=done,
390397
print_type=print_type,
@@ -399,13 +406,14 @@ def debugprint(
399406
print_op_info=print_op_info,
400407
print_destroy_map=print_destroy_map,
401408
print_view_map=print_view_map,
409+
is_inner_graph_header=True,
402410
)
403411

404412
if print_fgraph_inputs:
405413
for inp in inner_inputs:
406414
_debugprint(
407415
inp,
408-
prefix="-",
416+
prefix="",
409417
depth=depth,
410418
done=done,
411419
print_type=print_type,
@@ -485,6 +493,7 @@ def _debugprint(
485493
parent_node: Optional[Apply] = None,
486494
print_op_info: bool = False,
487495
inner_graph_node: Optional[Apply] = None,
496+
is_inner_graph_header: bool = False,
488497
) -> TextIO:
489498
r"""Print the graph represented by `var`.
490499
@@ -625,15 +634,18 @@ def get_id_str(
625634
else:
626635
data = ""
627636

628-
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
637+
if is_inner_graph_header:
638+
var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}"
639+
else:
640+
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
629641

630642
if print_op_info and node not in op_information:
631643
op_information.update(op_debug_information(node.op, node))
632644

633645
node_info = (
634646
parent_node and op_information.get(parent_node)
635647
) or op_information.get(node)
636-
if node_info and var in node_info:
648+
if node_info and var in node_info and not is_inner_graph_header:
637649
var_output = f"{var_output} ({node_info[var]})"
638650

639651
if profile and profile.apply_time and node in profile.apply_time:
@@ -660,12 +672,13 @@ def get_id_str(
660672
if not already_done and (
661673
not stop_on_name or not (hasattr(var, "name") and var.name is not None)
662674
):
663-
new_prefix = prefix_child + " |"
664-
new_prefix_child = prefix_child + " |"
675+
new_prefix = prefix_child + " ├─ "
676+
new_prefix_child = prefix_child + " "
665677

666678
for in_idx, in_var in enumerate(node.inputs):
667679
if in_idx == len(node.inputs) - 1:
668-
new_prefix_child = prefix_child + " "
680+
new_prefix = prefix_child + " └─ "
681+
new_prefix_child = prefix_child + " "
669682

670683
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
671684
if (
@@ -698,6 +711,8 @@ def get_id_str(
698711
print_view_map=print_view_map,
699712
inner_graph_node=inner_graph_node,
700713
)
714+
elif not is_inner_graph_header:
715+
print(prefix_child + " └─ ···", file=file)
701716
else:
702717
id_str = get_id_str(var)
703718

tests/compile/test_builders.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -572,18 +572,18 @@ def test_debugprint():
572572
lines = output_str.split("\n")
573573

574574
exp_res = """OpFromGraph{inline=False} [id A]
575-
|x [id B]
576-
|y [id C]
577-
|z [id D]
575+
├─ x [id B]
576+
├─ y [id C]
577+
└─ z [id D]
578578
579579
Inner graphs:
580580
581581
OpFromGraph{inline=False} [id A]
582-
>Add [id E]
583-
> |*0-<Matrix(float64, shape=(?, ?))> [id F]
584-
> |Mul [id G]
585-
> |*1-<Matrix(float64, shape=(?, ?))> [id H]
586-
> |*2-<Matrix(float64, shape=(?, ?))> [id I]
582+
Add [id E]
583+
├─ *0-<Matrix(float64, shape=(?, ?))> [id F]
584+
└─ Mul [id G]
585+
├─ *1-<Matrix(float64, shape=(?, ?))> [id H]
586+
└─ *2-<Matrix(float64, shape=(?, ?))> [id I]
587587
"""
588588

589589
for truth, out in zip(exp_res.split("\n"), lines):

0 commit comments

Comments
 (0)