
Implement OpFromGraph __eq__ and __hash__ #1114

Open

ricardoV94 opened this issue Dec 9, 2024 · 0 comments
Labels: help wanted (Extra attention is needed), OpFromGraph

ricardoV94 (Member) commented Dec 9, 2024

Description

Implementing these methods would allow duplicated OpFromGraph nodes to be merged, and would let graphs containing OpFromGraph nodes be compared for equality.
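Merging duplicated nodes comes down to recognising that two nodes apply equal ops to the same inputs. The sketch below is not PyTensor's merge rewrite. It is a minimal dictionary-based pass over a hypothetical `(op, inputs)` list encoding, using hypothetical `AddConst` and `OpaqueAddConst` ops, to show that a structural `__eq__` plus a consistent `__hash__` is what lets separately constructed ops collapse into one node.

```python
# Hypothetical toy ops; neither is a PyTensor class.
class AddConst:
    """Adds a constant; equality and hashing are structural."""

    def __init__(self, c):
        self.c = c

    def __eq__(self, other):
        return type(self) is type(other) and self.c == other.c

    def __hash__(self):
        # Must agree with __eq__: equal ops hash equally.
        return hash((type(self), self.c))


class OpaqueAddConst:
    """Same op without __eq__/__hash__: instances compare by identity."""

    def __init__(self, c):
        self.c = c


def merge_duplicates(nodes):
    """Return {node index: index of the node it merges into}.

    `nodes` is a list of (op, inputs) pairs in topological order; each input
    is either a graph-input name (str) or the index (int) of an earlier node.
    """
    first_seen = {}   # (op, merged inputs) -> index of first such node
    replacement = {}  # node index -> surviving node index
    for i, (op, inputs) in enumerate(nodes):
        # Point inputs at the surviving copies of already-merged nodes.
        merged_inputs = tuple(replacement.get(inp, inp) for inp in inputs)
        replacement[i] = first_seen.setdefault((op, merged_inputs), i)
    return replacement


def build(op_cls):
    # y1 = x + 1; y2 = x + 1; z1 = y1 + 2; z2 = y2 + 2
    return [
        (op_cls(1), ("x",)),
        (op_cls(1), ("x",)),
        (op_cls(2), (0,)),
        (op_cls(2), (1,)),
    ]


print(merge_duplicates(build(AddConst)))        # {0: 0, 1: 0, 2: 2, 3: 2}
print(merge_duplicates(build(OpaqueAddConst)))  # {0: 0, 1: 1, 2: 2, 3: 3}
```

With structural equality the second copy of each node merges into the first, and the merge then cascades to downstream nodes whose inputs became identical. With identity equality nothing merges.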

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import equal_computations

x = pt.scalar("x")
# Two separately constructed OpFromGraph ops wrapping the same inner graph
out1 = OpFromGraph([x], [x + 1])(x)
out2 = OpFromGraph([x], [x + 1])(x)

# Should pass once the two ops compare equal
assert equal_computations([out1], [out2])
```

The assert should pass, but it currently fails because `out1.owner.op == out2.owner.op` evaluates to False, so the two structurally identical ops are treated as different. We can probably do something very similar to what Scan does:

pytensor/pytensor/scan/op.py, lines 1254 to 1320 in 4b41e09:

```python
def __eq__(self, other):
    if type(self) is not type(other):
        return False
    if self.info != other.info:
        return False
    if self.profile != other.profile:
        return False
    if self.truncate_gradient != other.truncate_gradient:
        return False
    if self.name != other.name:
        return False
    if self.allow_gc != other.allow_gc:
        return False
    # Compare inner graphs
    # TODO: Use `self.inner_fgraph == other.inner_fgraph`
    if len(self.inner_inputs) != len(other.inner_inputs):
        return False
    if len(self.inner_outputs) != len(other.inner_outputs):
        return False
    # strict=False because length already compared above
    for self_in, other_in in zip(
        self.inner_inputs, other.inner_inputs, strict=False
    ):
        if self_in.type != other_in.type:
            return False
    return equal_computations(
        self.inner_outputs,
        other.inner_outputs,
        self.inner_inputs,
        other.inner_inputs,
    )

def __str__(self):
    inplace = "none"
    if self.destroy_map:
        # Check if all outputs are inplace
        if sorted(self.destroy_map) == sorted(
            range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
        ):
            inplace = "all"
        else:
            inplace = str(list(self.destroy_map))
    return (
        f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
    )

def __hash__(self):
    return hash(
        (
            type(self),
            self._hash_inner_graph,
            self.info,
            self.profile,
            self.truncate_gradient,
            self.name,
            self.allow_gc,
        )
    )
```
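Adapting the Scan approach to OpFromGraph mostly means deciding what counts as "the same" inner graph. A minimal sketch, assuming hypothetical `Var`, `Apply`, and `GraphOp` classes rather than PyTensor's real ones: the inner outputs are reduced once to a canonical nested tuple in which each inner input is replaced by its position. This plays the role that passing the inner inputs to `equal_computations` plays in Scan's `__eq__`, and `__hash__` is built from the same key, much as Scan hashes `_hash_inner_graph`.

```python
# Hypothetical stand-ins for graph variables, nodes, and an OpFromGraph-like op.
from dataclasses import dataclass


class Var:
    """Inner-graph variable: compared by identity, carries a type name."""

    def __init__(self, name, type_):
        self.name = name
        self.type = type_


@dataclass(frozen=True)
class Apply:
    """Inner-graph node: an op name applied to a tuple of inputs."""

    op: str
    inputs: tuple


def canonical(expr, positions):
    """Rewrite `expr` as nested tuples, replacing inner inputs by position."""
    if isinstance(expr, Var):
        if expr not in positions:
            raise ValueError(f"{expr.name} is not an inner input")
        return ("input", positions[expr])
    if isinstance(expr, (int, float)):
        return ("const", type(expr).__name__, expr)
    if isinstance(expr, Apply):
        return (expr.op, tuple(canonical(i, positions) for i in expr.inputs))
    raise TypeError(f"unsupported inner-graph element: {expr!r}")


class GraphOp:
    """OpFromGraph-like op whose equality is structural, in the spirit of Scan."""

    def __init__(self, inputs, outputs, name=None):
        self.inputs = list(inputs)
        self.outputs = list(outputs)
        self.name = name
        positions = {var: i for i, var in enumerate(self.inputs)}
        # Computed once: input types plus outputs with inner inputs replaced
        # by position, so the identities of the inner inputs don't matter.
        self._key = (
            tuple(var.type for var in self.inputs),
            tuple(canonical(out, positions) for out in self.outputs),
        )

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        # Scan also compares its name and options; do the same for `name`.
        return self.name == other.name and self._key == other._key

    def __hash__(self):
        # Built from the same fields as __eq__, so equal ops hash equally.
        return hash((type(self), self.name, self._key))


x, y = Var("x", "float64"), Var("y", "float64")
op1 = GraphOp([x], [Apply("add", (x, 1))])
op2 = GraphOp([y], [Apply("add", (y, 1))])  # same graph, different input
op3 = GraphOp([x], [Apply("add", (x, 2))])  # different constant
print(op1 == op2, hash(op1) == hash(op2))  # True True
print(op1 == op3)                          # False
print(len({op1, op2, op3}))                # 2
```

Unfolding the inner graph into a tree ignores shared subexpressions, which keeps the sketch short but can blow up on deep DAGs. A real implementation would more likely reuse `equal_computations` for `__eq__` and cache a hash of the inner graph, as Scan does.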

ricardoV94 added the help wanted and OpFromGraph labels Dec 9, 2024
ricardoV94 changed the title from "Implement OpFromGraph.__eq__" to "Implement OpFromGraph __eq__ and __hash__" Dec 9, 2024
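Covering `__hash__` as well as `__eq__` matches a Python rule worth keeping in mind: a class that overrides `__eq__` without defining `__hash__` has its `__hash__` set to None, so its instances can no longer be used in sets or as dict keys. The toy classes below are hypothetical and only illustrate that rule, next to the identity comparison a class gets when it defines neither method.

```python
class IdentityOp:
    """Hypothetical op without __eq__: Python compares instances by identity."""

    def __init__(self, expr):
        self.expr = expr


class EqOnlyOp:
    """Hypothetical op that defines __eq__ but not __hash__."""

    def __init__(self, expr):
        self.expr = expr

    def __eq__(self, other):
        return type(self) is type(other) and self.expr == other.expr


print(IdentityOp("x + 1") == IdentityOp("x + 1"))  # False: distinct objects
print(EqOnlyOp("x + 1") == EqOnlyOp("x + 1"))      # True: structural __eq__
print(EqOnlyOp.__hash__ is None)                   # True: hashing was disabled

try:
    {EqOnlyOp("x + 1")}
except TypeError as err:
    # Raised because EqOnlyOp instances are unhashable.
    print(err)
```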