
Improve Op string representation and debug_print formatting #319

Merged
merged 12 commits on Jun 8, 2023
126 changes: 63 additions & 63 deletions README.rst
@@ -22,69 +22,69 @@ Getting started

.. code-block:: python

import pytensor
from pytensor import tensor as pt

# Declare two symbolic floating-point scalars
a = pt.dscalar("a")
b = pt.dscalar("b")

# Create a simple example expression
c = a + b

# Convert the expression into a callable object that takes `(a, b)`
# values as input and computes the value of `c`.
f_c = pytensor.function([a, b], c)

assert f_c(1.5, 2.5) == 4.0

# Compute the gradient of the example expression with respect to `a`
dc = pytensor.grad(c, a)

f_dc = pytensor.function([a, b], dc)

assert f_dc(1.5, 2.5) == 1.0

# Compiling functions with `pytensor.function` also optimizes
# expression graphs by removing unnecessary operations and
# replacing computations with more efficient ones.

v = pt.vector("v")
M = pt.matrix("M")

d = a/a + (M + a).dot(v)

pytensor.dprint(d)
# Elemwise{add,no_inplace} [id A] ''
#  |InplaceDimShuffle{x} [id B] ''
#  | |Elemwise{true_div,no_inplace} [id C] ''
#  |   |a [id D]
#  |   |a [id D]
#  |dot [id E] ''
#    |Elemwise{add,no_inplace} [id F] ''
#    | |M [id G]
#    | |InplaceDimShuffle{x,x} [id H] ''
#    |   |a [id D]
#    |v [id I]

f_d = pytensor.function([a, v, M], d)

# `a/a` -> `1` and the dot product is replaced with a BLAS function
# (i.e. CGemv)
pytensor.dprint(f_d)
# Elemwise{Add}[(0, 1)] [id A] '' 5
#  |TensorConstant{(1,) of 1.0} [id B]
#  |CGemv{inplace} [id C] '' 4
#    |AllocEmpty{dtype='float64'} [id D] '' 3
#    | |Shape_i{0} [id E] '' 2
#    |   |M [id F]
#    |TensorConstant{1.0} [id G]
#    |Elemwise{add,no_inplace} [id H] '' 1
#    | |M [id F]
#    | |InplaceDimShuffle{x,x} [id I] '' 0
#    |   |a [id J]
#    |v [id K]
#    |TensorConstant{0.0} [id L]
import pytensor
from pytensor import tensor as pt

# Declare two symbolic floating-point scalars
a = pt.dscalar("a")
b = pt.dscalar("b")

# Create a simple example expression
c = a + b

# Convert the expression into a callable object that takes `(a, b)`
# values as input and computes the value of `c`.
f_c = pytensor.function([a, b], c)

assert f_c(1.5, 2.5) == 4.0

# Compute the gradient of the example expression with respect to `a`
dc = pytensor.grad(c, a)

f_dc = pytensor.function([a, b], dc)

assert f_dc(1.5, 2.5) == 1.0

# Compiling functions with `pytensor.function` also optimizes
# expression graphs by removing unnecessary operations and
# replacing computations with more efficient ones.

v = pt.vector("v")
M = pt.matrix("M")

d = a/a + (M + a).dot(v)

pytensor.dprint(d)
# Add [id A]
#  ├─ ExpandDims{axis=0} [id B]
#  │  └─ True_div [id C]
#  │     ├─ a [id D]
#  │     └─ a [id D]
#  └─ dot [id E]
#     ├─ Add [id F]
#     │  ├─ M [id G]
#     │  └─ ExpandDims{axes=[0, 1]} [id H]
#     │     └─ a [id D]
#     └─ v [id I]

f_d = pytensor.function([a, v, M], d)

# `a/a` -> `1` and the dot product is replaced with a BLAS function
# (i.e. CGemv)
pytensor.dprint(f_d)
# Add [id A] 5
#  ├─ [1.] [id B]
#  └─ CGemv{inplace} [id C] 4
#     ├─ AllocEmpty{dtype='float64'} [id D] 3
#     │  └─ Shape_i{0} [id E] 2
#     │     └─ M [id F]
#     ├─ 1.0 [id G]
#     ├─ Add [id H] 1
#     │  ├─ M [id F]
#     │  └─ ExpandDims{axes=[0, 1]} [id I] 0
#     │     └─ a [id J]
#     ├─ v [id K]
#     └─ 0.0 [id L]

See `the PyTensor documentation <https://pytensor.readthedocs.io/en/latest/>`__ for in-depth tutorials.
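Not part of the diff: a small usage sketch of the `debugprint`/`dprint` keyword options that the changes to ``pytensor/printing.py`` below interact with. `print_type` and `depth` are existing options; the exact rendering of the annotated output may differ from the examples above.

import pytensor
from pytensor import tensor as pt

x = pt.vector("x")
y = (x + 1).sum()

# Annotate every node with its type and stop expanding after three levels.
pytensor.dprint(y, print_type=True, depth=3)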

21 changes: 14 additions & 7 deletions pytensor/graph/basic.py
@@ -763,13 +763,20 @@ def signature(self):
return (self.type, self.data)

def __str__(self):
if self.name is not None:
return self.name
else:
name = str(self.data)
if len(name) > 20:
name = name[:10] + "..." + name[-10:]
return f"{type(self).__name__}{{{name}}}"
data_str = str(self.data)
if len(data_str) > 20:
data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()

if self.name is None:
return data_str

return f"{self.name}{{{data_str}}}"

def __repr__(self):
data_str = repr(self.data)
if len(data_str) > 20:
data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
return f"{type(self).__name__}({repr(self.type)}, data={data_str})"

def clone(self, **kwargs):
return self
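# --- Illustration only, not part of the diff ---
# The truncation rule shared by the new Constant.__str__ and Constant.__repr__
# above, as a standalone function:
def _truncate(data_str: str) -> str:
    # Keep only the first and last ten characters of long data strings.
    if len(data_str) > 20:
        return data_str[:10].strip() + " ... " + data_str[-10:].strip()
    return data_str

print(_truncate("0123456789"))                     # unchanged: 0123456789
print(_truncate("0123456789abcdefghijklmnopqrs"))  # 0123456789 ... jklmnopqrs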
5 changes: 5 additions & 0 deletions pytensor/graph/op.py
@@ -621,6 +621,11 @@ def make_thunk(
def __str__(self):
return getattr(type(self), "__name__", super().__str__())

def __repr__(self):
props = getattr(self, "__props__", ())
props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props)
return f"{self.__class__.__name__}({props})"


class _NoPythonOp(Op):
"""A class used to indicate that an `Op` does not provide a Python implementation.
1 change: 1 addition & 0 deletions pytensor/graph/utils.py
@@ -234,6 +234,7 @@ def __eq__(self, other):

dct["__eq__"] = __eq__

# FIXME: This overrides __str__ inheritance when props are provided
if "__str__" not in dct:
if len(props) == 0:

33 changes: 24 additions & 9 deletions pytensor/printing.py
@@ -291,7 +291,7 @@ def debugprint(
for var in inputs_to_print:
_debugprint(
var,
prefix="-",
prefix="",
depth=depth,
done=done,
print_type=print_type,
@@ -342,11 +342,17 @@

if len(inner_graph_vars) > 0:
print("", file=_file)
new_prefix = " >"
new_prefix_child = " >"
prefix = ""
new_prefix = prefix + " ← "
new_prefix_child = prefix + " "
print("Inner graphs:", file=_file)

printed_inner_graphs_nodes = set()
for ig_var in inner_graph_vars:
if ig_var.owner in printed_inner_graphs_nodes:
continue
else:
printed_inner_graphs_nodes.add(ig_var.owner)
# This is a work-around to maintain backward compatibility
# (e.g. to only print inner graphs that have been compiled through
# a call to `Op.prepare_node`)
@@ -385,6 +391,7 @@

_debugprint(
ig_var,
prefix=prefix,
depth=depth,
done=done,
print_type=print_type,
@@ -399,13 +406,14 @@
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
is_inner_graph_header=True,
)

if print_fgraph_inputs:
for inp in inner_inputs:
_debugprint(
inp,
prefix="-",
prefix="",
depth=depth,
done=done,
print_type=print_type,
@@ -485,6 +493,7 @@
parent_node: Optional[Apply] = None,
print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None,
is_inner_graph_header: bool = False,
) -> TextIO:
r"""Print the graph represented by `var`.

@@ -625,15 +634,18 @@ def get_id_str(
else:
data = ""

var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
if is_inner_graph_header:
var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}"
else:
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"

if print_op_info and node not in op_information:
op_information.update(op_debug_information(node.op, node))

node_info = (
parent_node and op_information.get(parent_node)
) or op_information.get(node)
if node_info and var in node_info:
if node_info and var in node_info and not is_inner_graph_header:
var_output = f"{var_output} ({node_info[var]})"

if profile and profile.apply_time and node in profile.apply_time:
@@ -660,12 +672,13 @@ def get_id_str(
if not already_done and (
not stop_on_name or not (hasattr(var, "name") and var.name is not None)
):
new_prefix = prefix_child + " |"
new_prefix_child = prefix_child + " |"
new_prefix = prefix_child + " ├─ "
new_prefix_child = prefix_child + " │ "

for in_idx, in_var in enumerate(node.inputs):
if in_idx == len(node.inputs) - 1:
new_prefix_child = prefix_child + " "
new_prefix = prefix_child + " └─ "
new_prefix_child = prefix_child + " "

if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if (
@@ -698,6 +711,8 @@ def get_id_str(
print_view_map=print_view_map,
inner_graph_node=inner_graph_node,
)
elif not is_inner_graph_header:
print(prefix_child + " └─ ···", file=file)
else:
id_str = get_id_str(var)

22 changes: 21 additions & 1 deletion pytensor/scalar/basic.py
@@ -4143,6 +4143,7 @@ class Composite(ScalarInnerGraphOp):

def __init__(self, inputs, outputs, name="Composite"):
self.name = name
self._name = None
# We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph.
@@ -4189,7 +4190,26 @@ def __init__(self, inputs, outputs, name="Composite"):
super().__init__()

def __str__(self):
return self.name
if self._name is not None:
return self._name

# Rename internal variables
for i, r in enumerate(self.fgraph.inputs):
r.name = f"i{int(i)}"
for i, r in enumerate(self.fgraph.outputs):
r.name = f"o{int(i)}"
io = set(self.fgraph.inputs + self.fgraph.outputs)
for i, r in enumerate(self.fgraph.variables):
if r not in io and len(self.fgraph.clients[r]) > 1:
r.name = f"t{int(i)}"

if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
self._name = "Composite{...}"
else:
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
self._name = f"Composite{{{outputs_str}}}"

return self._name
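# --- Illustration only, not part of the diff ---
# A simplified sketch (not the method itself) of the naming rule the new
# Composite.__str__ follows: single-output inner graphs with at most ten apply
# nodes get a pretty-printed expression, anything larger falls back to
# "Composite{...}"; `pretty_outputs` stands in for the pprint-ed outputs.
def composite_name(n_outputs, n_apply_nodes, pretty_outputs):
    # Abbreviate large or multi-output inner graphs.
    if n_outputs > 1 or n_apply_nodes > 10:
        return "Composite{...}"
    return f"Composite{{{pretty_outputs}}}"

print(composite_name(1, 3, "(i0 + i1)"))  # Composite{(i0 + i1)}
print(composite_name(2, 3, "(i0 + i1)"))  # Composite{...}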

def make_new_inplace(self, output_types_preference=None, name=None):
"""
21 changes: 6 additions & 15 deletions pytensor/scan/op.py
@@ -1282,27 +1282,18 @@ def __eq__(self, other):
)

def __str__(self):
device_str = "cpu"
if self.info.as_while:
name = "do_while"
else:
name = "for"
aux_txt = "%s"
inplace = "none"
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted(
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
):
aux_txt += "all_inplace,%s,%s}"
inplace = "all"
else:
aux_txt += "{inplace{"
for k in self.destroy_map.keys():
aux_txt += str(k) + ","
aux_txt += "},%s,%s}"
else:
aux_txt += "{%s,%s}"
aux_txt = aux_txt % (name, device_str, str(self.name))
return aux_txt
inplace = str(list(self.destroy_map.keys()))
return (
f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
)
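# --- Illustration only, not part of the diff ---
# A standalone sketch of the simplified Scan string introduced above; `name`,
# `as_while`, and `destroy_map` stand in for the attributes the real Op reads
# from `self` and `self.info`.
def scan_str(name, as_while, destroy_map, n_inplace_candidates):
    inplace = "none"
    if destroy_map:
        # "all" when every candidate output is computed in place,
        # otherwise list the in-place output indices explicitly.
        if sorted(destroy_map) == sorted(range(n_inplace_candidates)):
            inplace = "all"
        else:
            inplace = str(list(destroy_map))
    return f"Scan{{{name}, while_loop={as_while}, inplace={inplace}}}"

print(scan_str("scan_fn", False, {}, 2))               # Scan{scan_fn, while_loop=False, inplace=none}
print(scan_str("scan_fn", True, {0: [0], 1: [1]}, 2))  # Scan{scan_fn, while_loop=True, inplace=all}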

def __hash__(self):
return hash(