Skip to content

Switch to new backend branch #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
815 changes: 521 additions & 294 deletions Cargo.lock

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.23.0", features = ["extension-module"] }

egglog = { git = "https://github.com/egraphs-good/egglog", rev = "6f494282442803201b512e9d0828007b52a0b29c" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", rev = "8a1b3d6ad2723a8438f51f05027161e51f37917c" }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "367a9143be7cb5354a54a9c5660d117440db77a6" }
# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "8cd7c0e77a27c271cbaef09ab23514675d646937" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "12e23e1" }
core-relations = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "12e23e1" }
# egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", rev = "8a1b3d6ad2723a8438f51f05027161e51f37917c" }
egraph-serialize = { version = "0.2.0", features = ["serde", "graphviz"] }
serde_json = "1.0.140"
pyo3-log = "0.12.3"
Expand All @@ -23,9 +26,9 @@ ordered-float = "3.7.0"
uuid = { version = "1.16.0", features = ["v4"] }

# Use unreleased version of egglog in experimental
[patch.'https://github.com/egraphs-good/egglog']
# [patch.'https://github.com/egraphs-good/egglog']
# https://github.com/rust-lang/cargo/issues/5478#issuecomment-522719793
egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "6f494282442803201b512e9d0828007b52a0b29c" }
# egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "6f494282442803201b512e9d0828007b52a0b29c" }

# [replace]
# 'https://github.com/egraphs-good/egglog.git#[email protected]' = { git = "https://github.com/egraphs-good/egglog.git", rev = "215714e1cbb13ae9e21bed2f2e1bf95804571512" }
Expand Down
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ _This project uses semantic versioning_

## UNRELEASED

- Upgrade egglog which includes new backend. Removes support for egglog experimental including `Rational` since it
is not compatible with new backend yet.

## 10.0.1 (2025-04-06)

- Fix bug on resolving types if not all imported to your module [#286](https://github.com/egraphs-good/egglog-python/pull/286)
Expand Down
6 changes: 3 additions & 3 deletions docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ i64(10) + i64(2)
```

```{code-cell} python
# egg: (+ (rational 1 2) (rational 2 1))
Rational(i64(1), i64(2)) / Rational(i64(2), i64(1))
# egg: (+ (bigrat (bigint 1) (bigint 2)) (big-rat (bigint 2) (bigint 1)))
BigRat(1, 2) / BigRat(2, 1)
```

These types are also all checked statically with MyPy, so for example, if you try to add a `String` and a `i64`, you will get a type error.
Expand All @@ -44,7 +44,7 @@ i64(10) + 2
```

```{code-cell} python
Rational(1, 2) / Rational(2, 1)
BigRat(1, 2) / BigRat(2, 1)
```

### `!=` Operator
Expand Down
126 changes: 63 additions & 63 deletions python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"MapLike",
"MultiSet",
"PyObject",
"Rational",
"Set",
"SetLike",
"String",
Expand Down Expand Up @@ -527,91 +526,92 @@ def __add__(self, other: MultiSet[T]) -> MultiSet[T]: ...
def map(self, f: Callable[[T], T]) -> MultiSet[T]: ...


class Rational(BuiltinExpr):
@method(preserve=True)
def eval(self) -> Fraction:
call = _extract_call(self)
if call.callable != InitRef("Rational"):
msg = "Rational can only be initialized with the Rational constructor."
raise BuiltinEvalError(msg)
# Removed until egglog experimental supports new backend
# class Rational(BuiltinExpr):
# @method(preserve=True)
# def eval(self) -> Fraction:
# call = _extract_call(self)
# if call.callable != InitRef("Rational"):
# msg = "Rational can only be initialized with the Rational constructor."
# raise BuiltinEvalError(msg)

def _to_int(e: TypedExprDecl) -> int:
expr = e.expr
if not isinstance(expr, LitDecl):
msg = "Rational can only be initialized with literals"
raise BuiltinEvalError(msg)
assert isinstance(expr.value, int)
return expr.value
# def _to_int(e: TypedExprDecl) -> int:
# expr = e.expr
# if not isinstance(expr, LitDecl):
# msg = "Rational can only be initialized with literals"
# raise BuiltinEvalError(msg)
# assert isinstance(expr.value, int)
# return expr.value

num, den = call.args
return Fraction(_to_int(num), _to_int(den))
# num, den = call.args
# return Fraction(_to_int(num), _to_int(den))

@method(preserve=True)
def __float__(self) -> float:
return float(self.eval())
# @method(preserve=True)
# def __float__(self) -> float:
# return float(self.eval())

@method(preserve=True)
def __int__(self) -> int:
return int(self.eval())
# @method(preserve=True)
# def __int__(self) -> int:
# return int(self.eval())

@method(egg_fn="rational")
def __init__(self, num: i64Like, den: i64Like) -> None: ...
# @method(egg_fn="rational")
# def __init__(self, num: i64Like, den: i64Like) -> None: ...

@method(egg_fn="to-f64")
def to_f64(self) -> f64: ...
# @method(egg_fn="to-f64")
# def to_f64(self) -> f64: ...

@method(egg_fn="+")
def __add__(self, other: Rational) -> Rational: ...
# @method(egg_fn="+")
# def __add__(self, other: Rational) -> Rational: ...

@method(egg_fn="-")
def __sub__(self, other: Rational) -> Rational: ...
# @method(egg_fn="-")
# def __sub__(self, other: Rational) -> Rational: ...

@method(egg_fn="*")
def __mul__(self, other: Rational) -> Rational: ...
# @method(egg_fn="*")
# def __mul__(self, other: Rational) -> Rational: ...

@method(egg_fn="/")
def __truediv__(self, other: Rational) -> Rational: ...
# @method(egg_fn="/")
# def __truediv__(self, other: Rational) -> Rational: ...

@method(egg_fn="min")
def min(self, other: Rational) -> Rational: ...
# @method(egg_fn="min")
# def min(self, other: Rational) -> Rational: ...

@method(egg_fn="max")
def max(self, other: Rational) -> Rational: ...
# @method(egg_fn="max")
# def max(self, other: Rational) -> Rational: ...

@method(egg_fn="neg")
def __neg__(self) -> Rational: ...
# @method(egg_fn="neg")
# def __neg__(self) -> Rational: ...

@method(egg_fn="abs")
def __abs__(self) -> Rational: ...
# @method(egg_fn="abs")
# def __abs__(self) -> Rational: ...

@method(egg_fn="floor")
def floor(self) -> Rational: ...
# @method(egg_fn="floor")
# def floor(self) -> Rational: ...

@method(egg_fn="ceil")
def ceil(self) -> Rational: ...
# @method(egg_fn="ceil")
# def ceil(self) -> Rational: ...

@method(egg_fn="round")
def round(self) -> Rational: ...
# @method(egg_fn="round")
# def round(self) -> Rational: ...

@method(egg_fn="pow")
def __pow__(self, other: Rational) -> Rational: ...
# @method(egg_fn="pow")
# def __pow__(self, other: Rational) -> Rational: ...

@method(egg_fn="log")
def log(self) -> Rational: ...
# @method(egg_fn="log")
# def log(self) -> Rational: ...

@method(egg_fn="sqrt")
def sqrt(self) -> Rational: ...
# @method(egg_fn="sqrt")
# def sqrt(self) -> Rational: ...

@method(egg_fn="cbrt")
def cbrt(self) -> Rational: ...
# @method(egg_fn="cbrt")
# def cbrt(self) -> Rational: ...

@method(egg_fn="numer") # type: ignore[misc]
@property
def numer(self) -> i64: ...
# @method(egg_fn="numer") # type: ignore[misc]
# @property
# def numer(self) -> i64: ...

@method(egg_fn="denom") # type: ignore[misc]
@property
def denom(self) -> i64: ...
# @method(egg_fn="denom") # type: ignore[misc]
# @property
# def denom(self) -> i64: ...


class BigInt(BuiltinExpr):
Expand Down
8 changes: 4 additions & 4 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,10 @@ def test_set(self):
assert i64(1) in s
assert i64(3) not in s

def test_rational(self):
assert Rational(1, 2).eval() == Fraction(1, 2)
assert float(Rational(1, 2)) == 0.5
assert int(Rational(1, 1)) == 1
# def test_rational(self):
# assert Rational(1, 2).eval() == Fraction(1, 2)
# assert float(Rational(1, 2)) == 0.5
# assert int(Rational(1, 1)) == 1

def test_vec(self):
assert Vec[i64].empty().eval() == ()
Expand Down
40 changes: 22 additions & 18 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::error::{EggResult, WrappedError};
use crate::py_object_sort::ArcPyObjectSort;
use crate::serialize::SerializedEGraph;

use egglog::{span, SerializeConfig};
use egglog::{span, EGraph as EgglogEGraph, SerializeConfig};
use log::info;
use pyo3::prelude::*;
use std::path::PathBuf;
Expand Down Expand Up @@ -33,7 +33,7 @@ impl EGraph {
seminaive: bool,
record: bool,
) -> Self {
let mut egraph = egglog_experimental::new_experimental_egraph();
let mut egraph = EgglogEGraph::default();
egraph.fact_directory = fact_directory;
egraph.seminaive = seminaive;
if let Some(py_object_sort) = py_object_sort {
Expand Down Expand Up @@ -61,16 +61,17 @@ impl EGraph {
/// Returns a list of strings representing the output.
/// An EggSmolError is raised if there is problem parsing or executing.
#[pyo3(signature=(*commands))]
fn run_program(&mut self, commands: Vec<Command>) -> EggResult<Vec<String>> {
fn run_program(&mut self, py: Python<'_>, commands: Vec<Command>) -> EggResult<Vec<String>> {
let commands: Vec<egglog::ast::Command> = commands.into_iter().map(|x| x.into()).collect();
let mut cmds_str = String::new();
for cmd in &commands {
cmds_str = cmds_str + &cmd.to_string() + "\n";
}
info!("Running commands:\n{}", cmds_str);

let res = self.egraph.run_program(commands).map_err(|e| {
WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str)
let res = py.allow_threads(|| {
self.egraph.run_program(commands).map_err(|e| {
WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str)
})
});
if res.is_ok() {
if let Some(cmds) = &mut self.cmds {
Expand Down Expand Up @@ -115,22 +116,25 @@ impl EGraph {
)]
fn serialize(
&mut self,
py: Python<'_>,
root_eclasses: Vec<Expr>,
max_functions: Option<usize>,
max_calls_per_function: Option<usize>,
include_temporary_functions: bool,
) -> SerializedEGraph {
let root_eclasses: Vec<_> = root_eclasses
.into_iter()
.map(|x| self.egraph.eval_expr(&egglog::ast::Expr::from(x)).unwrap())
.collect();
SerializedEGraph {
egraph: self.egraph.serialize(SerializeConfig {
max_functions,
max_calls_per_function,
include_temporary_functions,
root_eclasses,
}),
}
py.allow_threads(|| {
let root_eclasses: Vec<_> = root_eclasses
.into_iter()
.map(|x| self.egraph.eval_expr(&egglog::ast::Expr::from(x)).unwrap())
.collect();
SerializedEGraph {
egraph: self.egraph.serialize(SerializeConfig {
max_functions,
max_calls_per_function,
include_temporary_functions,
root_eclasses,
}),
}
})
}
}
Loading
Loading