Skip to content

Commit 2f86f79

Browse files
committed
Improve string representation of Subtensor Ops
1 parent 9066338 commit 2f86f79

File tree

2 files changed

+112
-113
lines changed

2 files changed

+112
-113
lines changed

pytensor/tensor/subtensor.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -840,22 +840,34 @@ def __hash__(self):
840840

841841
@staticmethod
842842
def str_from_slice(entry):
843-
msg = []
844-
for x in [entry.start, entry.stop, entry.step]:
845-
if x is None:
846-
msg.append("")
847-
else:
848-
msg.append(str(x))
849-
return ":".join(msg)
843+
if entry.step:
844+
return ":".join(
845+
(
846+
"start" if entry.start else "",
847+
"stop" if entry.stop else "",
848+
"step",
849+
)
850+
)
851+
if entry.stop:
852+
return f"{'start' if entry.start else ''}:stop"
853+
if entry.start:
854+
return "start:"
855+
return ":"
850856

851-
def __str__(self):
857+
@staticmethod
858+
def str_from_indices(idx_list):
852859
indices = []
853-
for entry in self.idx_list:
860+
letter_indexes = 0
861+
for entry in idx_list:
854862
if isinstance(entry, slice):
855-
indices.append(self.str_from_slice(entry))
863+
indices.append(Subtensor.str_from_slice(entry))
856864
else:
857-
indices.append(str(entry))
858-
return f"{self.__class__.__name__}{{{', '.join(indices)}}}"
865+
indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
866+
letter_indexes += 1
867+
return ", ".join(indices)
868+
869+
def __str__(self):
870+
return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"
859871

860872
@staticmethod
861873
def default_helper_c_code_args():
@@ -1498,21 +1510,8 @@ def __hash__(self):
14981510
return hash((type(self), idx_list, self.inplace, self.set_instead_of_inc))
14991511

15001512
def __str__(self):
1501-
indices = []
1502-
for entry in self.idx_list:
1503-
if isinstance(entry, slice):
1504-
indices.append(Subtensor.str_from_slice(entry))
1505-
else:
1506-
indices.append(str(entry))
1507-
if self.inplace:
1508-
msg = "Inplace"
1509-
else:
1510-
msg = ""
1511-
if not self.set_instead_of_inc:
1512-
msg += "Inc"
1513-
else:
1514-
msg += "Set"
1515-
return f"{self.__class__.__name__}{{{msg};{', '.join(indices)}}}"
1513+
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
1514+
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
15161515

15171516
def make_node(self, x, y, *inputs):
15181517
"""
@@ -2661,10 +2660,10 @@ def __init__(
26612660
self.ignore_duplicates = ignore_duplicates
26622661

26632662
def __str__(self):
2664-
return "{}{{{}, {}}}".format(
2665-
self.__class__.__name__,
2666-
"inplace=" + str(self.inplace),
2667-
" set_instead_of_inc=" + str(self.set_instead_of_inc),
2663+
return (
2664+
"AdvancedSetSubtensor"
2665+
if self.set_instead_of_inc
2666+
else "AdvancedIncSubtensor"
26682667
)
26692668

26702669
def make_node(self, x, y, *inputs):

0 commit comments

Comments
 (0)