@@ -840,22 +840,34 @@ def __hash__(self):
840
840
841
841
@staticmethod
842
842
def str_from_slice (entry ):
843
- msg = []
844
- for x in [entry .start , entry .stop , entry .step ]:
845
- if x is None :
846
- msg .append ("" )
847
- else :
848
- msg .append (str (x ))
849
- return ":" .join (msg )
843
+ if entry .step :
844
+ return ":" .join (
845
+ (
846
+ "start" if entry .start else "" ,
847
+ "stop" if entry .stop else "" ,
848
+ "step" ,
849
+ )
850
+ )
851
+ if entry .stop :
852
+ return f"{ 'start' if entry .start else '' } :stop"
853
+ if entry .start :
854
+ return "start:"
855
+ return ":"
850
856
851
- def __str__ (self ):
857
+ @staticmethod
858
+ def str_from_indices (idx_list ):
852
859
indices = []
853
- for entry in self .idx_list :
860
+ letter_indexes = 0
861
+ for entry in idx_list :
854
862
if isinstance (entry , slice ):
855
- indices .append (self .str_from_slice (entry ))
863
+ indices .append (Subtensor .str_from_slice (entry ))
856
864
else :
857
- indices .append (str (entry ))
858
- return f"{ self .__class__ .__name__ } {{{ ', ' .join (indices )} }}"
865
+ indices .append ("ijk" [letter_indexes % 3 ] * (letter_indexes // 3 + 1 ))
866
+ letter_indexes += 1
867
+ return ", " .join (indices )
868
+
869
+ def __str__ (self ):
870
+ return f"{ self .__class__ .__name__ } {{{ self .str_from_indices (self .idx_list )} }}"
859
871
860
872
@staticmethod
861
873
def default_helper_c_code_args ():
@@ -1498,21 +1510,8 @@ def __hash__(self):
1498
1510
return hash ((type (self ), idx_list , self .inplace , self .set_instead_of_inc ))
1499
1511
1500
1512
def __str__ (self ):
1501
- indices = []
1502
- for entry in self .idx_list :
1503
- if isinstance (entry , slice ):
1504
- indices .append (Subtensor .str_from_slice (entry ))
1505
- else :
1506
- indices .append (str (entry ))
1507
- if self .inplace :
1508
- msg = "Inplace"
1509
- else :
1510
- msg = ""
1511
- if not self .set_instead_of_inc :
1512
- msg += "Inc"
1513
- else :
1514
- msg += "Set"
1515
- return f"{ self .__class__ .__name__ } {{{ msg } ;{ ', ' .join (indices )} }}"
1513
+ name = "SetSubtensor" if self .set_instead_of_inc else "IncSubtensor"
1514
+ return f"{ name } {{{ Subtensor .str_from_indices (self .idx_list )} }}"
1516
1515
1517
1516
def make_node (self , x , y , * inputs ):
1518
1517
"""
@@ -2661,10 +2660,10 @@ def __init__(
2661
2660
self .ignore_duplicates = ignore_duplicates
2662
2661
2663
2662
def __str__ (self ):
2664
- return "{}{{{}, {}}}" . format (
2665
- self . __class__ . __name__ ,
2666
- "inplace=" + str ( self .inplace ),
2667
- " set_instead_of_inc=" + str ( self . set_instead_of_inc ),
2663
+ return (
2664
+ "AdvancedSetSubtensor"
2665
+ if self .set_instead_of_inc
2666
+ else "AdvancedIncSubtensor"
2668
2667
)
2669
2668
2670
2669
def make_node (self , x , y , * inputs ):
0 commit comments