Skip to content

Commit 863f99b

Browse files
Merge pull request #228 from egraphs-good/fix-loopnest
Add ability to subsume default definitions
2 parents f79dee8 + 2e089f2 commit 863f99b

File tree

12 files changed

+540
-59
lines changed

12 files changed

+540
-59
lines changed

docs/changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Fix pretty printing of lambda functions
8+
- Add support for subsuming rewrite generated by default function and method definitions
9+
710
## 8.0.1 (2024-10-24)
811

912
- Upgrade dependencies including [egglog](https://github.com/egraphs-good/egglog/compare/saulshanabrook:egg-smol:a555b2f5e82c684442775cc1a5da94b71930113c...b0db06832264c9b22694bd3de2bdacd55bbe9e32)

docs/reference/contributing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Then install the package in editable mode with the development dependencies:
3333
uv sync --all-extras
3434
```
3535

36-
Anytime you change the rust code, you can run `uv sync` to recompile the rust code.
36+
Anytime you change the rust code, you can run `uv sync --reinstall-package egglog --all-extras` to force recompiling the rust code.
3737

3838
If you would like to download a new version of the visualizer source, run `make clean; make`. This will download
3939
the most recent released version from the github actions artifact in the [egraph-visualizer](https://github.com/egraphs-good/egraph-visualizer) repo. It is checked in because it's a pain to get cargo to include only one git ignored file while ignoring the rest of the files that were ignored.

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,14 @@ array = [
3535
"numba==0.59.1",
3636
"llvmlite==0.42.0",
3737
]
38-
dev = ["ruff", "pre-commit", "mypy", "anywidget[dev]", "egglog[docs,test]"]
38+
dev = [
39+
"ruff",
40+
"pre-commit",
41+
"mypy",
42+
"anywidget[dev]",
43+
"egglog[docs,test]",
44+
"jupyterlab",
45+
]
3946

4047
test = [
4148
"pytest",

python/egglog/declarations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@ class RuleDecl:
764764
class DefaultRewriteDecl:
765765
ref: CallableRef
766766
expr: ExprDecl
767+
subsume: bool
767768

768769

769770
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl

python/egglog/egraph.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def method(
269269
unextractable: bool = False,
270270
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
271271
return lambda fn: _WrappedMethod(
272-
egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable
272+
egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable, False
273273
)
274274

275275
@overload
@@ -404,6 +404,7 @@ def method(
404404
on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None,
405405
mutates_self: bool = False,
406406
unextractable: bool = False,
407+
subsume: bool = False,
407408
) -> Callable[[CALLABLE], CALLABLE]: ...
408409

409410

@@ -417,6 +418,7 @@ def method(
417418
on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
418419
mutates_self: bool = False,
419420
unextractable: bool = False,
421+
subsume: bool = False,
420422
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
421423

422424

@@ -430,11 +432,14 @@ def method(
430432
preserve: bool = False,
431433
mutates_self: bool = False,
432434
unextractable: bool = False,
435+
subsume: bool = False,
433436
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
434437
"""
435438
Any method can be decorated with this to customize it's behavior. This is only supported in classes which subclass :class:`Expr`.
436439
"""
437-
return lambda fn: _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable)
440+
return lambda fn: _WrappedMethod(
441+
egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable, subsume
442+
)
438443

439444

440445
class _ExprMetaclass(type):
@@ -519,7 +524,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
519524
(inner_tp,) = v.__args__
520525
type_ref = resolve_type_annotation(decls, inner_tp)
521526
cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
522-
_add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
527+
_add_default_rewrite(
528+
decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset, subsume=False
529+
)
523530
else:
524531
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
525532
raise NotImplementedError(msg)
@@ -542,12 +549,12 @@ def _generate_class_decls( # noqa: C901,PLR0912
542549
if is_init and cls_name in LIT_CLASS_NAMES:
543550
continue
544551
match method:
545-
case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
552+
case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable, subsume):
546553
pass
547554
case _:
548555
egg_fn, cost, default, merge, on_merge = None, None, None, None, None
549556
fn = method
550-
unextractable, preserve = False, False
557+
unextractable, preserve, subsume = False, False, False
551558
mutates = method_name in ALWAYS_MUTATES_SELF
552559
if preserve:
553560
cls_decl.preserved_methods[method_name] = fn
@@ -572,7 +579,20 @@ def _generate_class_decls( # noqa: C901,PLR0912
572579
continue
573580

574581
_, add_rewrite = _fn_decl(
575-
decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
582+
decls,
583+
egg_fn,
584+
ref,
585+
fn,
586+
locals,
587+
default,
588+
cost,
589+
merge,
590+
on_merge,
591+
mutates,
592+
builtin,
593+
ruleset=ruleset,
594+
unextractable=unextractable,
595+
subsume=subsume,
576596
)
577597

578598
if not builtin and not isinstance(ref, InitRef) and not mutates:
@@ -602,6 +622,7 @@ def function(
602622
builtin: bool = False,
603623
ruleset: Ruleset | None = None,
604624
use_body_as_name: bool = False,
625+
subsume: bool = False,
605626
) -> Callable[[CALLABLE], CALLABLE]: ...
606627

607628

@@ -617,6 +638,7 @@ def function(
617638
unextractable: bool = False,
618639
ruleset: Ruleset | None = None,
619640
use_body_as_name: bool = False,
641+
subsume: bool = False,
620642
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
621643

622644

@@ -649,6 +671,7 @@ class _FunctionConstructor:
649671
unextractable: bool = False
650672
ruleset: Ruleset | None = None
651673
use_body_as_name: bool = False
674+
subsume: bool = False
652675

653676
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
654677
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
@@ -668,7 +691,8 @@ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, Ca
668691
self.on_merge,
669692
self.mutates_first_arg,
670693
self.builtin,
671-
self.ruleset,
694+
ruleset=self.ruleset,
695+
subsume=self.subsume,
672696
unextractable=self.unextractable,
673697
)
674698
add_rewrite()
@@ -690,6 +714,7 @@ def _fn_decl(
690714
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
691715
mutates_first_arg: bool,
692716
is_builtin: bool,
717+
subsume: bool,
693718
ruleset: Ruleset | None = None,
694719
unextractable: bool = False,
695720
) -> tuple[CallableRef, Callable[[], None]]:
@@ -804,7 +829,7 @@ def _fn_decl(
804829
res_ref = ref
805830
decls.set_function_decl(ref, decl)
806831
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
807-
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
832+
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
808833

809834

810835
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -871,7 +896,7 @@ def _constant_thunk(
871896
type_ref = resolve_type_annotation(decls, tp)
872897
callable_ref = ConstantRef(name)
873898
decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
874-
_add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
899+
_add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset, subsume=False)
875900
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
876901

877902

@@ -898,15 +923,21 @@ def _add_default_rewrite_function(
898923
res_type: TypeOrVarRef,
899924
ruleset: Ruleset | None,
900925
value_thunk: Callable[[], object],
926+
subsume: bool,
901927
) -> None:
902928
"""
903929
Helper functions that resolves a value thunk to create the default value.
904930
"""
905-
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
931+
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
906932

907933

908934
def _add_default_rewrite(
909-
decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
935+
decls: Declarations,
936+
ref: CallableRef,
937+
type_ref: TypeOrVarRef,
938+
default_rewrite: object,
939+
ruleset: Ruleset | None,
940+
subsume: bool,
910941
) -> None:
911942
"""
912943
Adds a default rewrite for the callable, if the default rewrite is not None
@@ -916,7 +947,7 @@ def _add_default_rewrite(
916947
if default_rewrite is None:
917948
return
918949
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
919-
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
950+
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
920951
if ruleset:
921952
ruleset_decls = ruleset._current_egg_decls
922953
ruleset_decl = ruleset.__egg_ruleset__
@@ -1341,8 +1372,6 @@ def saturate(
13411372
from .visualizer_widget import VisualizerWidget
13421373

13431374
def to_json() -> str:
1344-
if expr:
1345-
print(self.extract(expr))
13461375
return self._serialize(**kwargs).to_json()
13471376

13481377
egraphs = [to_json()]
@@ -1407,6 +1436,7 @@ class _WrappedMethod(Generic[P, EXPR]):
14071436
preserve: bool
14081437
mutates_self: bool
14091438
unextractable: bool
1439+
subsume: bool
14101440

14111441
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR:
14121442
msg = "We should never call a wrapped method. Did you forget to wrap the class?"

python/egglog/egraph_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
134134
)
135135
return bindings.RuleCommand(name or "", ruleset, rule)
136136
# TODO: Replace with just constants value and looking at REF of function
137-
case DefaultRewriteDecl(ref, expr):
137+
case DefaultRewriteDecl(ref, expr, subsume):
138138
decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
139139
sig = decl.signature
140140
assert isinstance(sig, FunctionSignature)
@@ -144,7 +144,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
144144
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
145145
)
146146
rewrite_decl = RewriteDecl(
147-
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
147+
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), subsume
148148
)
149149
return self.command_to_egg(rewrite_decl, ruleset)
150150
case _:

python/egglog/exp/array_api.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515

1616
from egglog import *
17-
from egglog.bindings import EggSmolError
1817
from egglog.runtime import RuntimeExpr
1918

2019
from .program_gen import *
@@ -272,7 +271,6 @@ def var(cls, name: StringLike) -> TupleInt: ...
272271

273272
EMPTY: ClassVar[TupleInt]
274273

275-
@method(unextractable=True)
276274
def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...
277275

278276
@classmethod
@@ -287,6 +285,7 @@ def range(cls, stop: Int) -> TupleInt:
287285
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
288286
return TupleInt(vec.length(), partial(index_vec_int, vec))
289287

288+
@method(subsume=True)
290289
def __add__(self, other: TupleInt) -> TupleInt:
291290
return TupleInt(
292291
self.length() + other.length(),
@@ -308,13 +307,13 @@ def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
308307

309308
def fold_boolean(self, init: Boolean, f: Callable[[Boolean, Int], Boolean]) -> Boolean: ...
310309

310+
@method(subsume=True)
311311
def contains(self, i: Int) -> Boolean:
312312
return self.fold_boolean(FALSE, lambda acc, j: acc | (i == j))
313313

314-
@method(cost=100)
315314
def filter(self, f: Callable[[Int], Boolean]) -> TupleInt: ...
316315

317-
@method(cost=100)
316+
@method(subsume=True)
318317
def map(self, f: Callable[[Int], Int]) -> TupleInt:
319318
return TupleInt(self.length(), lambda i: f(self[i]))
320319

@@ -372,7 +371,7 @@ def _tuple_int(
372371
ne(k).to(i64(0)),
373372
),
374373
# Empty
375-
rewrite(TupleInt.EMPTY).to(TupleInt(0, bottom_indexing)),
374+
rewrite(TupleInt.EMPTY, subsume=True).to(TupleInt(0, bottom_indexing)),
376375
# if_
377376
rewrite(TupleInt.if_(TRUE, ti, ti2)).to(ti),
378377
rewrite(TupleInt.if_(FALSE, ti, ti2)).to(ti2),
@@ -388,13 +387,16 @@ def var(cls, name: StringLike) -> TupleTupleInt: ...
388387
def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...
389388

390389
@classmethod
390+
@method(subsume=True)
391391
def single(cls, i: TupleInt) -> TupleTupleInt:
392392
return TupleTupleInt(Int(1), lambda _: i)
393393

394394
@classmethod
395+
@method(subsume=True)
395396
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
396397
return TupleInt(vec.length(), partial(index_vec_int, vec))
397398

399+
@method(subsume=True)
398400
def __add__(self, other: TupleTupleInt) -> TupleTupleInt:
399401
return TupleTupleInt(
400402
self.length() + other.length(),
@@ -732,7 +734,7 @@ def _tuple_value(
732734
rewrite(TupleValue.EMPTY.includes(v)).to(FALSE),
733735
rewrite(TupleValue(v).includes(v)).to(TRUE),
734736
rewrite(TupleValue(v).includes(v2)).to(FALSE, ne(v).to(v2)),
735-
rewrite((ti + ti2).includes(v)).to(ti.includes(v) | ti2.includes(v)),
737+
rewrite((ti + ti2).includes(v), subsume=True).to(ti.includes(v) | ti2.includes(v)),
736738
]
737739

738740

@@ -1539,13 +1541,14 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
15391541
egraph.run(array_api_schedule)
15401542
try:
15411543
extracted = egraph.extract(prim_expr)
1542-
except EggSmolError as exc:
1544+
# Catch base exceptions so that we catch rust panics which happen when trying to extract subsumed nodes
1545+
except BaseException as exc:
1546+
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
15431547
# Try giving some context, by showing the smallest version of the larger expression
15441548
try:
15451549
expr_extracted = egraph.extract(expr)
1546-
except EggSmolError as inner_exc:
1550+
except BaseException as inner_exc:
15471551
raise ValueError(f"Cannot simplify {expr}") from inner_exc
1548-
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
15491552
msg = f"Cannot simplify to primitive {expr_extracted}"
15501553
raise ValueError(msg) from exc
15511554
return egraph.eval(extracted)

0 commit comments

Comments
 (0)