@@ -269,7 +269,7 @@ def method(
269
269
unextractable : bool = False ,
270
270
) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]:
271
271
return lambda fn : _WrappedMethod (
272
- egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable
272
+ egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable , False
273
273
)
274
274
275
275
@overload
@@ -404,6 +404,7 @@ def method(
404
404
on_merge : Callable [[Any , Any ], Iterable [ActionLike ]] | None = None ,
405
405
mutates_self : bool = False ,
406
406
unextractable : bool = False ,
407
+ subsume : bool = False ,
407
408
) -> Callable [[CALLABLE ], CALLABLE ]: ...
408
409
409
410
@@ -417,6 +418,7 @@ def method(
417
418
on_merge : Callable [[EXPR , EXPR ], Iterable [ActionLike ]] | None = None ,
418
419
mutates_self : bool = False ,
419
420
unextractable : bool = False ,
421
+ subsume : bool = False ,
420
422
) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]: ...
421
423
422
424
@@ -430,11 +432,14 @@ def method(
430
432
preserve : bool = False ,
431
433
mutates_self : bool = False ,
432
434
unextractable : bool = False ,
435
+ subsume : bool = False ,
433
436
) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]:
434
437
"""
435
438
Any method can be decorated with this to customize it's behavior. This is only supported in classes which subclass :class:`Expr`.
436
439
"""
437
- return lambda fn : _WrappedMethod (egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable )
440
+ return lambda fn : _WrappedMethod (
441
+ egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable , subsume
442
+ )
438
443
439
444
440
445
class _ExprMetaclass (type ):
@@ -519,7 +524,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
519
524
(inner_tp ,) = v .__args__
520
525
type_ref = resolve_type_annotation (decls , inner_tp )
521
526
cls_decl .class_variables [k ] = ConstantDecl (type_ref .to_just ())
522
- _add_default_rewrite (decls , ClassVariableRef (cls_name , k ), type_ref , namespace .pop (k , None ), ruleset )
527
+ _add_default_rewrite (
528
+ decls , ClassVariableRef (cls_name , k ), type_ref , namespace .pop (k , None ), ruleset , subsume = False
529
+ )
523
530
else :
524
531
msg = f"On class { cls_name } , for attribute '{ k } ', expected a ClassVar, but got { v } "
525
532
raise NotImplementedError (msg )
@@ -542,12 +549,12 @@ def _generate_class_decls( # noqa: C901,PLR0912
542
549
if is_init and cls_name in LIT_CLASS_NAMES :
543
550
continue
544
551
match method :
545
- case _WrappedMethod (egg_fn , cost , default , merge , on_merge , fn , preserve , mutates , unextractable ):
552
+ case _WrappedMethod (egg_fn , cost , default , merge , on_merge , fn , preserve , mutates , unextractable , subsume ):
546
553
pass
547
554
case _:
548
555
egg_fn , cost , default , merge , on_merge = None , None , None , None , None
549
556
fn = method
550
- unextractable , preserve = False , False
557
+ unextractable , preserve , subsume = False , False , False
551
558
mutates = method_name in ALWAYS_MUTATES_SELF
552
559
if preserve :
553
560
cls_decl .preserved_methods [method_name ] = fn
@@ -572,7 +579,20 @@ def _generate_class_decls( # noqa: C901,PLR0912
572
579
continue
573
580
574
581
_ , add_rewrite = _fn_decl (
575
- decls , egg_fn , ref , fn , locals , default , cost , merge , on_merge , mutates , builtin , ruleset , unextractable
582
+ decls ,
583
+ egg_fn ,
584
+ ref ,
585
+ fn ,
586
+ locals ,
587
+ default ,
588
+ cost ,
589
+ merge ,
590
+ on_merge ,
591
+ mutates ,
592
+ builtin ,
593
+ ruleset = ruleset ,
594
+ unextractable = unextractable ,
595
+ subsume = subsume ,
576
596
)
577
597
578
598
if not builtin and not isinstance (ref , InitRef ) and not mutates :
@@ -602,6 +622,7 @@ def function(
602
622
builtin : bool = False ,
603
623
ruleset : Ruleset | None = None ,
604
624
use_body_as_name : bool = False ,
625
+ subsume : bool = False ,
605
626
) -> Callable [[CALLABLE ], CALLABLE ]: ...
606
627
607
628
@@ -617,6 +638,7 @@ def function(
617
638
unextractable : bool = False ,
618
639
ruleset : Ruleset | None = None ,
619
640
use_body_as_name : bool = False ,
641
+ subsume : bool = False ,
620
642
) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]: ...
621
643
622
644
@@ -649,6 +671,7 @@ class _FunctionConstructor:
649
671
unextractable : bool = False
650
672
ruleset : Ruleset | None = None
651
673
use_body_as_name : bool = False
674
+ subsume : bool = False
652
675
653
676
def __call__ (self , fn : Callable [..., RuntimeExpr ]) -> RuntimeFunction :
654
677
return RuntimeFunction (* split_thunk (Thunk .fn (self .create_decls , fn )))
@@ -668,7 +691,8 @@ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, Ca
668
691
self .on_merge ,
669
692
self .mutates_first_arg ,
670
693
self .builtin ,
671
- self .ruleset ,
694
+ ruleset = self .ruleset ,
695
+ subsume = self .subsume ,
672
696
unextractable = self .unextractable ,
673
697
)
674
698
add_rewrite ()
@@ -690,6 +714,7 @@ def _fn_decl(
690
714
on_merge : Callable [[RuntimeExpr , RuntimeExpr ], Iterable [ActionLike ]] | None ,
691
715
mutates_first_arg : bool ,
692
716
is_builtin : bool ,
717
+ subsume : bool ,
693
718
ruleset : Ruleset | None = None ,
694
719
unextractable : bool = False ,
695
720
) -> tuple [CallableRef , Callable [[], None ]]:
@@ -804,7 +829,7 @@ def _fn_decl(
804
829
res_ref = ref
805
830
decls .set_function_decl (ref , decl )
806
831
res_thunk = Thunk .fn (_create_default_value , decls , ref , fn , args , ruleset )
807
- return res_ref , Thunk .fn (_add_default_rewrite_function , decls , res_ref , return_type , ruleset , res_thunk )
832
+ return res_ref , Thunk .fn (_add_default_rewrite_function , decls , res_ref , return_type , ruleset , res_thunk , subsume )
808
833
809
834
810
835
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -871,7 +896,7 @@ def _constant_thunk(
871
896
type_ref = resolve_type_annotation (decls , tp )
872
897
callable_ref = ConstantRef (name )
873
898
decls ._constants [name ] = ConstantDecl (type_ref .to_just (), egg_name )
874
- _add_default_rewrite (decls , callable_ref , type_ref , default_replacement , ruleset )
899
+ _add_default_rewrite (decls , callable_ref , type_ref , default_replacement , ruleset , subsume = False )
875
900
return decls , TypedExprDecl (type_ref .to_just (), CallDecl (callable_ref ))
876
901
877
902
@@ -898,15 +923,21 @@ def _add_default_rewrite_function(
898
923
res_type : TypeOrVarRef ,
899
924
ruleset : Ruleset | None ,
900
925
value_thunk : Callable [[], object ],
926
+ subsume : bool ,
901
927
) -> None :
902
928
"""
903
929
Helper functions that resolves a value thunk to create the default value.
904
930
"""
905
- _add_default_rewrite (decls , ref , res_type , value_thunk (), ruleset )
931
+ _add_default_rewrite (decls , ref , res_type , value_thunk (), ruleset , subsume )
906
932
907
933
908
934
def _add_default_rewrite (
909
- decls : Declarations , ref : CallableRef , type_ref : TypeOrVarRef , default_rewrite : object , ruleset : Ruleset | None
935
+ decls : Declarations ,
936
+ ref : CallableRef ,
937
+ type_ref : TypeOrVarRef ,
938
+ default_rewrite : object ,
939
+ ruleset : Ruleset | None ,
940
+ subsume : bool ,
910
941
) -> None :
911
942
"""
912
943
Adds a default rewrite for the callable, if the default rewrite is not None
@@ -916,7 +947,7 @@ def _add_default_rewrite(
916
947
if default_rewrite is None :
917
948
return
918
949
resolved_value = resolve_literal (type_ref , default_rewrite , Thunk .value (decls ))
919
- rewrite_decl = DefaultRewriteDecl (ref , resolved_value .__egg_typed_expr__ .expr )
950
+ rewrite_decl = DefaultRewriteDecl (ref , resolved_value .__egg_typed_expr__ .expr , subsume )
920
951
if ruleset :
921
952
ruleset_decls = ruleset ._current_egg_decls
922
953
ruleset_decl = ruleset .__egg_ruleset__
@@ -1341,8 +1372,6 @@ def saturate(
1341
1372
from .visualizer_widget import VisualizerWidget
1342
1373
1343
1374
def to_json () -> str :
1344
- if expr :
1345
- print (self .extract (expr ))
1346
1375
return self ._serialize (** kwargs ).to_json ()
1347
1376
1348
1377
egraphs = [to_json ()]
@@ -1407,6 +1436,7 @@ class _WrappedMethod(Generic[P, EXPR]):
1407
1436
preserve : bool
1408
1437
mutates_self : bool
1409
1438
unextractable : bool
1439
+ subsume : bool
1410
1440
1411
1441
def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> EXPR :
1412
1442
msg = "We should never call a wrapped method. Did you forget to wrap the class?"
0 commit comments