@@ -526,18 +526,11 @@ def _graph_flatten(
526
526
paths : list [PathParts ] | None ,
527
527
return_variables : bool ,
528
528
) -> NodeDef [tp .Any ] | NodeRef :
529
- is_pytree_node_ = isinstance (node_impl , PytreeNodeImpl )
530
- is_graph_node_ = isinstance (node_impl , GraphNodeImpl )
531
-
532
- if not is_pytree_node_ and node in ref_index :
529
+ if node in ref_index :
533
530
return NodeRef (type (node ), ref_index [node ])
534
531
535
- # only cache graph nodes
536
- if is_graph_node_ :
537
- index = len (ref_index )
538
- ref_index [node ] = index
539
- else :
540
- index = - 1
532
+ # assign index
533
+ ref_index [node ] = index = len (ref_index )
541
534
542
535
attributes : list [
543
536
tuple [Key , Static [tp .Any ] | NodeDef [tp .Any ] | VariableDef | NodeRef [tp .Any ]]
@@ -603,7 +596,7 @@ def _graph_flatten(
603
596
type = node_impl .type , # type: ignore[arg-type]
604
597
index = index ,
605
598
outer_index = ref_outer_index [node ]
606
- if is_graph_node_ and ref_outer_index and node in ref_outer_index
599
+ if ref_outer_index and node in ref_outer_index
607
600
else None ,
608
601
attributes = tuple (attributes ),
609
602
metadata = metadata ,
@@ -646,23 +639,18 @@ def _graph_fingerprint(
646
639
ref_index : RefMap ,
647
640
new_ref_index : RefMap ,
648
641
):
649
- is_pytree_node_ = type (node_impl ) is PytreeNodeImpl
650
- is_graph_node_ = type (node_impl ) is GraphNodeImpl
651
642
652
643
append_fn (type (node ))
653
644
654
- if is_graph_node_ :
655
- append_fn (id (node ))
656
- if node in ref_index :
657
- append_fn (ref_index [node ])
658
- return
659
- elif node in new_ref_index :
660
- append_fn (new_ref_index [node ])
661
- return
662
- index = new_ref_index [node ] = ctx .next_index
663
- ctx .next_index += 1
664
- else :
665
- index = - 1
645
+ append_fn (id (node ))
646
+ if node in ref_index :
647
+ append_fn (ref_index [node ])
648
+ return
649
+ elif node in new_ref_index :
650
+ append_fn (new_ref_index [node ])
651
+ return
652
+ index = new_ref_index [node ] = ctx .next_index
653
+ ctx .next_index += 1
666
654
667
655
values , metadata = node_impl .flatten (node )
668
656
@@ -732,26 +720,20 @@ def _check_graph_fingerprint(
732
720
ref_index : RefMap ,
733
721
new_ref_index : RefMap ,
734
722
) -> bool :
735
- is_pytree_node_ = type (node_impl ) is PytreeNodeImpl
736
- is_graph_node_ = type (node_impl ) is GraphNodeImpl
737
-
738
723
if type (node ) != next (fp_iterator ):
739
724
return False
740
725
741
- if is_graph_node_ :
742
- # append_fn(id(node))
743
- if id (node ) != next (fp_iterator ):
744
- return False
745
- if node in ref_index :
746
- # append_fn(ref_index[node])
747
- return ref_index [node ] == next (fp_iterator )
748
- elif node in new_ref_index :
749
- # append_fn(new_ref_index[node])
750
- return new_ref_index [node ] == next (fp_iterator )
751
- index = new_ref_index [node ] = ctx .next_index
752
- ctx .next_index += 1
753
- else :
754
- index = - 1
726
+ # append_fn(id(node))
727
+ if id (node ) != next (fp_iterator ):
728
+ return False
729
+ if node in ref_index :
730
+ # append_fn(ref_index[node])
731
+ return ref_index [node ] == next (fp_iterator )
732
+ elif node in new_ref_index :
733
+ # append_fn(new_ref_index[node])
734
+ return new_ref_index [node ] == next (fp_iterator )
735
+ index = new_ref_index [node ] = ctx .next_index
736
+ ctx .next_index += 1
755
737
756
738
values , metadata = node_impl .flatten (node )
757
739
@@ -993,6 +975,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
993
975
# if the node type does not support the creation of an empty object it means
994
976
# that it cannot reference itself, so we can create its children first
995
977
node = node_impl .unflatten (_get_children (), nodedef .metadata )
978
+ index_ref [nodedef .index ] = node
996
979
997
980
return node
998
981
0 commit comments