Skip to content

Commit 9c8f936

Browse files
Cristian GarciaFlax Authors
authored andcommitted
Add pytrees as graph nodes again.
PiperOrigin-RevId: 734367472
1 parent e3789de commit 9c8f936

File tree

2 files changed

+28
-45
lines changed

2 files changed

+28
-45
lines changed

flax/nnx/graph.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -526,18 +526,11 @@ def _graph_flatten(
526526
paths: list[PathParts] | None,
527527
return_variables: bool,
528528
) -> NodeDef[tp.Any] | NodeRef:
529-
is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl)
530-
is_graph_node_ = isinstance(node_impl, GraphNodeImpl)
531-
532-
if not is_pytree_node_ and node in ref_index:
529+
if node in ref_index:
533530
return NodeRef(type(node), ref_index[node])
534531

535-
# only cache graph nodes
536-
if is_graph_node_:
537-
index = len(ref_index)
538-
ref_index[node] = index
539-
else:
540-
index = -1
532+
# assign index
533+
ref_index[node] = index = len(ref_index)
541534

542535
attributes: list[
543536
tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
@@ -603,7 +596,7 @@ def _graph_flatten(
603596
type=node_impl.type, # type: ignore[arg-type]
604597
index=index,
605598
outer_index=ref_outer_index[node]
606-
if is_graph_node_ and ref_outer_index and node in ref_outer_index
599+
if ref_outer_index and node in ref_outer_index
607600
else None,
608601
attributes=tuple(attributes),
609602
metadata=metadata,
@@ -646,23 +639,18 @@ def _graph_fingerprint(
646639
ref_index: RefMap,
647640
new_ref_index: RefMap,
648641
):
649-
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
650-
is_graph_node_ = type(node_impl) is GraphNodeImpl
651642

652643
append_fn(type(node))
653644

654-
if is_graph_node_:
655-
append_fn(id(node))
656-
if node in ref_index:
657-
append_fn(ref_index[node])
658-
return
659-
elif node in new_ref_index:
660-
append_fn(new_ref_index[node])
661-
return
662-
index = new_ref_index[node] = ctx.next_index
663-
ctx.next_index += 1
664-
else:
665-
index = -1
645+
append_fn(id(node))
646+
if node in ref_index:
647+
append_fn(ref_index[node])
648+
return
649+
elif node in new_ref_index:
650+
append_fn(new_ref_index[node])
651+
return
652+
index = new_ref_index[node] = ctx.next_index
653+
ctx.next_index += 1
666654

667655
values, metadata = node_impl.flatten(node)
668656

@@ -732,26 +720,20 @@ def _check_graph_fingerprint(
732720
ref_index: RefMap,
733721
new_ref_index: RefMap,
734722
) -> bool:
735-
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
736-
is_graph_node_ = type(node_impl) is GraphNodeImpl
737-
738723
if type(node) != next(fp_iterator):
739724
return False
740725

741-
if is_graph_node_:
742-
# append_fn(id(node))
743-
if id(node) != next(fp_iterator):
744-
return False
745-
if node in ref_index:
746-
# append_fn(ref_index[node])
747-
return ref_index[node] == next(fp_iterator)
748-
elif node in new_ref_index:
749-
# append_fn(new_ref_index[node])
750-
return new_ref_index[node] == next(fp_iterator)
751-
index = new_ref_index[node] = ctx.next_index
752-
ctx.next_index += 1
753-
else:
754-
index = -1
726+
# append_fn(id(node))
727+
if id(node) != next(fp_iterator):
728+
return False
729+
if node in ref_index:
730+
# append_fn(ref_index[node])
731+
return ref_index[node] == next(fp_iterator)
732+
elif node in new_ref_index:
733+
# append_fn(new_ref_index[node])
734+
return new_ref_index[node] == next(fp_iterator)
735+
index = new_ref_index[node] = ctx.next_index
736+
ctx.next_index += 1
755737

756738
values, metadata = node_impl.flatten(node)
757739

@@ -993,6 +975,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
993975
# if the node type does not support the creation of an empty object it means
994976
# that it cannot reference itself, so we can create its children first
995977
node = node_impl.unflatten(_get_children(), nodedef.metadata)
978+
index_ref[nodedef.index] = node
996979

997980
return node
998981

tests/nnx/graph_utils_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_flatten(self):
6969
assert flat_state[0][1].value == 2
7070
assert flat_state[1][1].value == 4
7171

72-
assert len(refmap) == 2
72+
assert len(refmap) == 4
7373
assert a['b'] in refmap
7474
assert g[3] in refmap
7575

@@ -85,7 +85,7 @@ def test_flatten_no_paths(self):
8585
assert flat_state[0] == 2
8686
assert flat_state[1] == 4
8787

88-
assert len(refmap) == 2
88+
assert len(refmap) == 4
8989
assert a['b'] in refmap
9090
assert g[3] in refmap
9191

@@ -116,7 +116,7 @@ def test_unflatten_pytree(self):
116116
graphdef, state = nnx.split(g)
117117
g = nnx.merge(graphdef, state)
118118

119-
assert g[0] is not g[2]
119+
assert g[0] is g[2]
120120

121121
def test_unflatten_empty(self):
122122
a = Dict({'a': 1, 'b': nnx.Param(2)})

0 commit comments

Comments
 (0)