Skip to content

Commit 1dddcfe

Browse files
committed
[nnx] add more flaxlib types
1 parent 33f82b5 commit 1dddcfe

File tree

8 files changed

+492
-200
lines changed

8 files changed

+492
-200
lines changed

benchmarks/nnx_simple_training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ def test_step_nnx(model: MLP, batch):
116116
loss = jnp.mean((y - y_pred) ** 2)
117117
return {'loss': loss}
118118

119-
cached_train_step_nnx = nnx.cached_partial(train_step_nnx, model, optimizer)
120-
cached_test_step_nnx = nnx.cached_partial(test_step_nnx, model)
119+
# cached_train_step_nnx = nnx.cached_partial(train_step_nnx, model, optimizer)
120+
# cached_test_step_nnx = nnx.cached_partial(test_step_nnx, model)
121121

122122
for step, batch in enumerate(dataset(X, Y, batch_size)):
123-
cached_train_step_nnx(batch)
123+
train_step_nnx(model, optimizer, batch)
124124

125125
if step % 1000 == 0:
126-
logs = cached_test_step_nnx((X, Y))
126+
logs = test_step_nnx(model, (X, Y))
127127

128128
if step >= total_steps - 1:
129129
break

flax/nnx/graph.py

Lines changed: 83 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import threading
2121
import typing as tp
2222

23+
from flax import config
2324
from flax.nnx import filterlib, reprlib, traversals, variablelib
2425
from flax.nnx import statelib
2526
from flax.nnx.proxy_caller import (
@@ -63,27 +64,47 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
6364
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
6465
return isinstance(x, Variable)
6566

67+
class IndexMap(dict[Index, tp.Any]):
68+
@staticmethod
69+
def from_refmap(refmap: RefMap) -> IndexMap:
70+
indexmap = IndexMap()
71+
indexmap.update((index, value) for value, index in refmap.items())
72+
return indexmap
73+
74+
if config.flax_use_flaxlib:
75+
import flaxlib
76+
77+
globals()['IndexMap'] = flaxlib.IndexMap
78+
6679

6780
# RefMap = dict
68-
class RefMap(tp.MutableMapping[A, B]):
81+
class RefMap(tp.MutableMapping[tp.Any, int]):
6982
"""A mapping that hashes keys by their identity."""
7083

7184
def __init__(
72-
self,
73-
mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] | None = None,
74-
/,
85+
self,
86+
mapping: tp.Mapping[tp.Any, int]
87+
| tp.Iterable[tuple[tp.Any, int]]
88+
| None = None,
89+
/,
7590
):
76-
self._mapping: dict[int, tuple[A, B]] = dict()
91+
self._mapping: dict[int, tuple[tp.Any, int]] = dict()
7792
if mapping is not None:
7893
self.update(mapping)
7994

80-
def __getitem__(self, key: A) -> B:
95+
@staticmethod
96+
def from_indexmap(indexmap: IndexMap) -> RefMap:
97+
refmap = RefMap()
98+
refmap.update((value, index) for index, value in indexmap.items())
99+
return refmap
100+
101+
def __getitem__(self, key: tp.Any) -> int:
81102
return self._mapping[id(key)][1]
82103

83-
def __setitem__(self, key: A, value: B):
104+
def __setitem__(self, key: tp.Any, value: int):
84105
self._mapping[id(key)] = (key, value)
85106

86-
def __delitem__(self, key: A):
107+
def __delitem__(self, key: tp.Any):
87108
del self._mapping[id(key)]
88109

89110
def __len__(self) -> int:
@@ -92,14 +113,20 @@ def __len__(self) -> int:
92113
def __contains__(self, key: tp.Any) -> bool:
93114
return id(key) in self._mapping
94115

95-
def __iter__(self) -> tp.Iterator[A]:
116+
def __iter__(self) -> tp.Iterator[tp.Any]:
96117
for key, _ in self._mapping.values():
97118
yield key
98119

99-
def items(self) -> tp.ItemsView[A, B]:
120+
def items(self) -> tp.ItemsView[tp.Any, int]:
100121
return self._mapping.values() # type: ignore
101122

102123

124+
if config.flax_use_flaxlib:
125+
import flaxlib
126+
127+
globals()['RefMap'] = flaxlib.RefMap
128+
129+
103130
@dataclasses.dataclass(frozen=True, slots=True)
104131
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
105132
type: type[Node]
@@ -258,6 +285,11 @@ def __treescope_repr__(self, path, subtree_renderer):
258285
subtree_renderer=subtree_renderer,
259286
)
260287

288+
if config.flax_use_flaxlib:
289+
import flaxlib
290+
291+
jax.tree_util.register_static(flaxlib.NodeRef)
292+
globals()['NodeRef'] = flaxlib.NodeRef
261293

262294
@jax.tree_util.register_static
263295
@dataclasses.dataclass(frozen=True, repr=False)
@@ -299,6 +331,11 @@ def __treescope_repr__(self, path, subtree_renderer):
299331
subtree_renderer=subtree_renderer,
300332
)
301333

334+
if config.flax_use_flaxlib:
335+
import flaxlib
336+
337+
jax.tree_util.register_static(flaxlib.VariableDef)
338+
globals()['VariableDef'] = flaxlib.VariableDef
302339

303340
@jax.tree_util.register_static
304341
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
@@ -331,9 +368,6 @@ def with_same_outer_index(self) -> NodeDef[Node]:
331368
metadata=self.metadata,
332369
)
333370

334-
def replace(self, **kwargs):
335-
return dataclasses.replace(self, **kwargs)
336-
337371
def __nnx_repr__(self):
338372
yield reprlib.Object(type=type(self))
339373

@@ -358,6 +392,13 @@ def __treescope_repr__(self, path, subtree_renderer):
358392
)
359393

360394

395+
if config.flax_use_flaxlib:
396+
import flaxlib
397+
398+
jax.tree_util.register_static(flaxlib.NodeDef)
399+
globals()['NodeDef'] = flaxlib.NodeDef
400+
401+
361402
@jax.tree_util.register_static
362403
@dataclasses.dataclass(frozen=True, slots=True)
363404
class ArrayAttr:
@@ -548,22 +589,23 @@ def _graph_flatten(
548589
node: Node,
549590
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
550591
path: list[Key] | None,
551-
ref_index: RefMap[tp.Any, int],
592+
ref_index: RefMap,
552593
ref_outer_index: RefMap | None,
553594
nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]],
554595
attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]],
555596
leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray],
556597
paths: list[PathParts] | None,
557598
return_variables: bool,
558599
) -> None:
559-
is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl)
560-
is_graph_node_ = isinstance(node_impl, GraphNodeImpl)
561-
is_variable = isinstance(node, Variable)
600+
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
562601

563602
if not is_pytree_node_ and node in ref_index:
564603
nodes.append(NodeRef(index := ref_index[node]))
565604
return
566605

606+
is_graph_node_ = type(node_impl) is GraphNodeImpl
607+
is_variable = isinstance(node, Variable)
608+
567609
# only cache graph nodes
568610
if is_graph_node_ or is_variable:
569611
index = len(ref_index)
@@ -599,13 +641,13 @@ def _graph_flatten(
599641
values, metadata = node_impl.flatten(node)
600642
num_attributes = len(values)
601643
nodedef = NodeDef(
602-
type=node_impl.type,
603-
index=index,
604-
outer_index=ref_outer_index[node]
644+
node_impl.type,
645+
index,
646+
ref_outer_index[node]
605647
if is_graph_node_ and ref_outer_index and node in ref_outer_index
606648
else None,
607-
num_attributes=num_attributes,
608-
metadata=metadata,
649+
num_attributes,
650+
metadata,
609651
)
610652
nodes.append(nodedef)
611653

@@ -865,8 +907,8 @@ def unflatten(
865907
state: State[Key, tp.Any] | FlatState[tp.Any] | list[tp.Any],
866908
/,
867909
*,
868-
index_ref: dict[Index, tp.Any] | None = None,
869-
outer_index_outer_ref: dict[Index, tp.Any] | None = None,
910+
index_ref: IndexMap | None = None,
911+
outer_index_outer_ref: IndexMap | None = None,
870912
) -> Node:
871913
"""Unflattens a graphdef into a node with the given state.
872914
@@ -892,7 +934,7 @@ def unflatten(
892934
else:
893935
raise ValueError(f'Unsupported state type: {type(state)}')
894936
if index_ref is None:
895-
index_ref = {}
937+
index_ref = IndexMap()
896938

897939
if len(leaves) != graphdef.num_leaves:
898940
raise ValueError(
@@ -936,8 +978,8 @@ def _graph_unflatten(
936978
tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]
937979
],
938980
leaves_iter: tp.Iterator[tp.Any],
939-
index_ref: dict[Index, tp.Any],
940-
outer_index_outer_ref: dict[Index, tp.Any] | None,
981+
index_ref: IndexMap,
982+
outer_index_outer_ref: IndexMap | None,
941983
) -> Node:
942984
"""Recursive helper for graph_unflatten.
943985
@@ -1001,7 +1043,7 @@ def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
10011043
assert type(nodedef) is NodeDef
10021044
if node_impl is None:
10031045
raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')
1004-
if nodedef.index in index_ref:
1046+
if nodedef.index is not None and nodedef.index in index_ref:
10051047
raise RuntimeError(f'GraphDef index {nodedef.index} already used.')
10061048

10071049
def _get_children() -> list[tuple[Key, tp.Any]]:
@@ -1214,7 +1256,7 @@ class StaticCache(tp.NamedTuple):
12141256
paths: tuple[PathParts, ...]
12151257
variables: list[Variable[tp.Any]]
12161258
new_ref_index: RefMap
1217-
new_index_ref: dict[Index, tp.Any]
1259+
new_index_ref: IndexMap
12181260

12191261
@staticmethod
12201262
def create(
@@ -1223,7 +1265,7 @@ def create(
12231265
variables: list[Variable[tp.Any]],
12241266
new_ref_index: RefMap,
12251267
):
1226-
new_index_ref = {index: obj for obj, index in new_ref_index.items()}
1268+
new_index_ref = IndexMap.from_refmap(new_ref_index)
12271269
final_graphdef: GraphDef[tp.Any]
12281270
final_graphdef = graphdef.with_same_outer_index()
12291271
return StaticCache(
@@ -1243,15 +1285,15 @@ class GraphContext(threading.local):
12431285
)
12441286
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
12451287
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
1246-
tmp_static_cache: RefMap[tp.Any, StaticCache] | None = None
1288+
tmp_static_cache: RefMap | None = None
12471289
caching: bool = False
12481290

12491291

12501292
GRAPH_CONTEXT = GraphContext()
12511293

12521294

12531295
@contextlib.contextmanager
1254-
def static_cache(static_cache: RefMap[tp.Any, StaticCache]):
1296+
def static_cache(static_cache: RefMap):
12551297
if GRAPH_CONTEXT.caching:
12561298
yield
12571299
return
@@ -1314,9 +1356,9 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
13141356
Returns:
13151357
A partial function expecting the remaining arguments to the original function.
13161358
"""
1317-
cache: RefMap[tp.Any, StaticCache] = RefMap()
1359+
cache: RefMap = RefMap()
13181360
original_ref_index: RefMap = RefMap()
1319-
index_ref: dict[Index, tp.Any] = {}
1361+
index_ref: IndexMap = IndexMap()
13201362
cached_ref_index: RefMap = RefMap()
13211363

13221364
def create_static_cache(x):
@@ -1542,7 +1584,7 @@ def split_context(ctxtag: tp.Hashable | None = None):
15421584
@dataclasses.dataclass
15431585
class MergeContext:
15441586
ctxtag: tp.Hashable | None
1545-
index_ref: dict[Index, tp.Any]
1587+
index_ref: IndexMap
15461588
is_inner: bool | None
15471589

15481590
def merge(
@@ -1668,7 +1710,7 @@ def merge_context(): ...
16681710
def merge_context(ctxtag: tp.Hashable | None, inner: bool | None): ...
16691711
@contextlib.contextmanager
16701712
def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None = None):
1671-
GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, {}, inner))
1713+
GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, IndexMap(), inner))
16721714

16731715
try:
16741716
yield GRAPH_CONTEXT.index_ref_stack[-1]
@@ -1691,11 +1733,11 @@ class UpdateContext:
16911733

16921734
tag: tp.Hashable
16931735
outer_ref_outer_index: RefMap | None
1694-
outer_index_inner_ref: dict[Index, tp.Any] | None
1736+
outer_index_inner_ref: IndexMap | None
16951737
# reverse caches
1696-
outer_index_outer_ref: dict[Index, tp.Any] | None
1738+
outer_index_outer_ref: IndexMap | None
16971739
inner_ref_outer_index: RefMap | None
1698-
static_cache: RefMap[tp.Any, StaticCache] | None
1740+
static_cache: RefMap | None
16991741

17001742
# define hash and eq to make this an opaque object
17011743
def __hash__(self):
@@ -1716,13 +1758,11 @@ def flatten_end(self, ref_index: RefMap):
17161758
self.outer_index_inner_ref = None
17171759
self.inner_ref_outer_index = None
17181760

1719-
def unflatten_end(self, index_ref: dict[Index, tp.Any], inner_merge: bool):
1761+
def unflatten_end(self, index_ref: IndexMap, inner_merge: bool):
17201762
if inner_merge:
17211763
# inner merge (2)
17221764
self.outer_index_inner_ref = index_ref
1723-
self.inner_ref_outer_index = RefMap(
1724-
(obj, index) for index, obj in index_ref.items()
1725-
)
1765+
self.inner_ref_outer_index = RefMap.from_indexmap(index_ref)
17261766

17271767

17281768
@dataclasses.dataclass

flaxlib_src/src/flaxlib/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@
1313
# limitations under the License.
1414

1515
from .flaxlib_cpp import RefMap as RefMap
16-
from .flaxlib_cpp import _graph_fingerprint as _graph_fingerprint
16+
from .flaxlib_cpp import IndexMap as IndexMap
17+
from .flaxlib_cpp import NodeDef as NodeDef
18+
from .flaxlib_cpp import VariableDef as VariableDef
19+
from .flaxlib_cpp import NodeRef as NodeRef

0 commit comments

Comments
 (0)