Skip to content

Implement OpFromGraph __eq__ and __hash__ #1114

Open
@ricardoV94

Description

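Implementing `__eq__` and `__hash__` for `OpFromGraph` would allow duplicated `OpFromGraph` nodes to be merged and graphs containing them to be compared for equality. Minimal example: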

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import equal_computations

# Minimal reproduction: two OpFromGraph ops built independently from the same
# inner graph (x -> x + 1) should be treated as equal.
x = pt.scalar("x")
out1 = OpFromGraph([x], [x + 1])(x)
out2 = OpFromGraph([x], [x + 1])(x)

op1, op2 = out1.owner.op, out2.owner.op

# Root cause reported in this issue: this comparison currently evaluates to
# False, so it is checked first to make the failure point explicit.
assert op1 == op2
# Python's data model requires objects that compare equal to have equal
# hashes; without this, hash-based deduplication of equal ops cannot work.
assert hash(op1) == hash(op2)

# User-facing symptom: graph equality fails because the ops differ.
assert equal_computations([out1], [out2])

The assertion should pass, but it currently fails because `out1.owner.op == out2.owner.op` evaluates to `False`. We can probably follow the same approach as `Scan`:

pytensor/pytensor/scan/op.py

Lines 1254 to 1320 in 4b41e09

def __eq__(self, other):
# Structural equality for Scan ops. Two instances are equal when they have the
# exact same type, the same info/profile/truncate_gradient/name/allow_gc, the
# same number of inner inputs and outputs, pairwise-equal inner input types,
# and inner output graphs that equal_computations considers equal.
# (Quoted here as a template for an OpFromGraph implementation.)
# Cheap attribute checks run first so mismatches exit before the graph walk.
if type(self) is not type(other):
return False
if self.info != other.info:
return False
if self.profile != other.profile:
return False
if self.truncate_gradient != other.truncate_gradient:
return False
if self.name != other.name:
return False
if self.allow_gc != other.allow_gc:
return False
# Compare inner graphs
# TODO: Use `self.inner_fgraph == other.inner_fgraph`
if len(self.inner_inputs) != len(other.inner_inputs):
return False
if len(self.inner_outputs) != len(other.inner_outputs):
return False
# strict=False because length already compared above
# Inner inputs are distinct variable objects in each op, so only their types
# are compared here; their identity is handled by equal_computations below.
for self_in, other_in in zip(
self.inner_inputs, other.inner_inputs, strict=False
):
if self_in.type != other_in.type:
return False
# The inner inputs are passed as the "inputs considered equal" arguments,
# presumably pairing self.inner_inputs[i] with other.inner_inputs[i] by
# position while the output graphs are compared — confirm against the
# equal_computations documentation.
return equal_computations(
self.inner_outputs,
other.inner_outputs,
self.inner_inputs,
other.inner_inputs,
)
def __str__(self):
# Human-readable label of the form
#   Scan{<name>, while_loop=<info.as_while>, inplace=<mode>}
# where <mode> is:
#   "none"  when destroy_map is empty (falsy),
#   "all"   when destroy_map's keys are exactly
#           0 .. n_mit_mot + n_mit_sot + n_sit_sot - 1,
#   otherwise the list of destroy_map keys.
inplace = "none"
if self.destroy_map:
# Check if all outputs are inplace
# (i.e. every mit_mot/mit_sot/sit_sot output index is a destroy_map key;
# comparing sorted lists also requires there are no extra keys)
if sorted(self.destroy_map) == sorted(
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
):
inplace = "all"
else:
inplace = str(list(self.destroy_map))
return (
f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
)
def __hash__(self):
# Hash a tuple of the op's type, its inner-graph hash, and the configuration
# attributes that __eq__ compares (info, profile, truncate_gradient, name,
# allow_gc). Objects that compare equal must hash equally, so every field here
# must also take part in __eq__, and all of them must be hashable.
# NOTE(review): _hash_inner_graph is defined elsewhere; presumably it is a
# hash of the inner graph that agrees with the equal_computations check in
# __eq__ (i.e. it is insensitive to inner-input identity) — confirm.
return hash(
(
type(self),
self._hash_inner_graph,
self.info,
self.profile,
self.truncate_gradient,
self.name,
self.allow_gc,
)
)

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions