@@ -291,7 +291,7 @@ def debugprint(
291
291
for var in inputs_to_print :
292
292
_debugprint (
293
293
var ,
294
- prefix = "- " ,
294
+ prefix = "→ " ,
295
295
depth = depth ,
296
296
done = done ,
297
297
print_type = print_type ,
@@ -342,11 +342,17 @@ def debugprint(
342
342
343
343
if len (inner_graph_vars ) > 0 :
344
344
print ("" , file = _file )
345
- new_prefix = " >"
346
- new_prefix_child = " >"
345
+ prefix = ""
346
+ new_prefix = prefix + " ← "
347
+ new_prefix_child = prefix + " "
347
348
print ("Inner graphs:" , file = _file )
348
349
350
+ printed_inner_graphs_nodes = set ()
349
351
for ig_var in inner_graph_vars :
352
+ if ig_var .owner in printed_inner_graphs_nodes :
353
+ continue
354
+ else :
355
+ printed_inner_graphs_nodes .add (ig_var .owner )
350
356
# This is a work-around to maintain backward compatibility
351
357
# (e.g. to only print inner graphs that have been compiled through
352
358
# a call to `Op.prepare_node`)
@@ -385,6 +391,7 @@ def debugprint(
385
391
386
392
_debugprint (
387
393
ig_var ,
394
+ prefix = prefix ,
388
395
depth = depth ,
389
396
done = done ,
390
397
print_type = print_type ,
@@ -399,13 +406,14 @@ def debugprint(
399
406
print_op_info = print_op_info ,
400
407
print_destroy_map = print_destroy_map ,
401
408
print_view_map = print_view_map ,
409
+ is_inner_graph_header = True ,
402
410
)
403
411
404
412
if print_fgraph_inputs :
405
413
for inp in inner_inputs :
406
414
_debugprint (
407
415
inp ,
408
- prefix = "- " ,
416
+ prefix = " → " ,
409
417
depth = depth ,
410
418
done = done ,
411
419
print_type = print_type ,
@@ -485,6 +493,7 @@ def _debugprint(
485
493
parent_node : Optional [Apply ] = None ,
486
494
print_op_info : bool = False ,
487
495
inner_graph_node : Optional [Apply ] = None ,
496
+ is_inner_graph_header : bool = False ,
488
497
) -> TextIO :
489
498
r"""Print the graph represented by `var`.
490
499
@@ -625,15 +634,18 @@ def get_id_str(
625
634
else :
626
635
data = ""
627
636
628
- var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
637
+ if is_inner_graph_header :
638
+ var_output = f"{ prefix } { node .op } { id_str } { destroy_map_str } { view_map_str } { o } "
639
+ else :
640
+ var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
629
641
630
642
if print_op_info and node not in op_information :
631
643
op_information .update (op_debug_information (node .op , node ))
632
644
633
645
node_info = (
634
646
parent_node and op_information .get (parent_node )
635
647
) or op_information .get (node )
636
- if node_info and var in node_info :
648
+ if node_info and var in node_info and not is_inner_graph_header :
637
649
var_output = f"{ var_output } ({ node_info [var ]} )"
638
650
639
651
if profile and profile .apply_time and node in profile .apply_time :
@@ -660,12 +672,13 @@ def get_id_str(
660
672
if not already_done and (
661
673
not stop_on_name or not (hasattr (var , "name" ) and var .name is not None )
662
674
):
663
- new_prefix = prefix_child + " | "
664
- new_prefix_child = prefix_child + " | "
675
+ new_prefix = prefix_child + " ├─ "
676
+ new_prefix_child = prefix_child + " │ "
665
677
666
678
for in_idx , in_var in enumerate (node .inputs ):
667
679
if in_idx == len (node .inputs ) - 1 :
668
- new_prefix_child = prefix_child + " "
680
+ new_prefix = prefix_child + " └─ "
681
+ new_prefix_child = prefix_child + " "
669
682
670
683
if hasattr (in_var , "owner" ) and hasattr (in_var .owner , "op" ):
671
684
if (
@@ -698,6 +711,8 @@ def get_id_str(
698
711
print_view_map = print_view_map ,
699
712
inner_graph_node = inner_graph_node ,
700
713
)
714
+ elif not is_inner_graph_header :
715
+ print (prefix_child + " └─ ···" , file = file )
701
716
else :
702
717
id_str = get_id_str (var )
703
718
0 commit comments