Skip to content

Commit 8bf63ad

Browse files
committed
[nnx] add more flaxlib types
1 parent 2a5df33 commit 8bf63ad

File tree

8 files changed

+459
-200
lines changed

8 files changed

+459
-200
lines changed

benchmarks/nnx_simple_training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ def test_step_nnx(model: MLP, batch):
116116
loss = jnp.mean((y - y_pred) ** 2)
117117
return {'loss': loss}
118118

119-
cached_train_step_nnx = nnx.cached_partial(train_step_nnx, model, optimizer)
120-
cached_test_step_nnx = nnx.cached_partial(test_step_nnx, model)
119+
# cached_train_step_nnx = nnx.cached_partial(train_step_nnx, model, optimizer)
120+
# cached_test_step_nnx = nnx.cached_partial(test_step_nnx, model)
121121

122122
for step, batch in enumerate(dataset(X, Y, batch_size)):
123-
cached_train_step_nnx(batch)
123+
train_step_nnx(model, optimizer, batch)
124124

125125
if step % 1000 == 0:
126-
logs = cached_test_step_nnx((X, Y))
126+
logs = test_step_nnx(model, (X, Y))
127127

128128
if step >= total_steps - 1:
129129
break

flax/nnx/graph.py

Lines changed: 83 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import threading
2121
import typing as tp
2222

23+
from flax import config
2324
from flax.nnx import filterlib, reprlib, traversals, variablelib
2425
from flax.nnx import statelib
2526
from flax.nnx.proxy_caller import (
@@ -63,27 +64,47 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
6364
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
6465
return isinstance(x, Variable)
6566

67+
class IndexMap(dict[Index, tp.Any]):
68+
@staticmethod
69+
def from_refmap(refmap: RefMap) -> IndexMap:
70+
indexmap = IndexMap()
71+
indexmap.update((index, value) for value, index in refmap.items())
72+
return indexmap
73+
74+
if config.flax_use_flaxlib:
75+
import flaxlib
76+
77+
globals()['IndexMap'] = flaxlib.IndexMap
78+
6679

6780
# RefMap = dict
68-
class RefMap(tp.MutableMapping[A, B]):
81+
class RefMap(tp.MutableMapping[tp.Any, int]):
6982
"""A mapping that hashes keys by their identity."""
7083

7184
def __init__(
72-
self,
73-
mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] | None = None,
74-
/,
85+
self,
86+
mapping: tp.Mapping[tp.Any, int]
87+
| tp.Iterable[tuple[tp.Any, int]]
88+
| None = None,
89+
/,
7590
):
76-
self._mapping: dict[int, tuple[A, B]] = dict()
91+
self._mapping: dict[int, tuple[tp.Any, int]] = dict()
7792
if mapping is not None:
7893
self.update(mapping)
7994

80-
def __getitem__(self, key: A) -> B:
95+
@staticmethod
96+
def from_indexmap(indexmap: IndexMap) -> RefMap:
97+
refmap = RefMap()
98+
refmap.update((value, index) for index, value in indexmap.items())
99+
return refmap
100+
101+
def __getitem__(self, key: tp.Any) -> int:
81102
return self._mapping[id(key)][1]
82103

83-
def __setitem__(self, key: A, value: B):
104+
def __setitem__(self, key: tp.Any, value: int):
84105
self._mapping[id(key)] = (key, value)
85106

86-
def __delitem__(self, key: A):
107+
def __delitem__(self, key: tp.Any):
87108
del self._mapping[id(key)]
88109

89110
def __len__(self) -> int:
@@ -92,14 +113,20 @@ def __len__(self) -> int:
92113
def __contains__(self, key: tp.Any) -> bool:
93114
return id(key) in self._mapping
94115

95-
def __iter__(self) -> tp.Iterator[A]:
116+
def __iter__(self) -> tp.Iterator[tp.Any]:
96117
for key, _ in self._mapping.values():
97118
yield key
98119

99-
def items(self) -> tp.ItemsView[A, B]:
120+
def items(self) -> tp.ItemsView[tp.Any, int]:
100121
return self._mapping.values() # type: ignore
101122

102123

124+
if config.flax_use_flaxlib:
125+
import flaxlib
126+
127+
globals()['RefMap'] = flaxlib.RefMap
128+
129+
103130
@dataclasses.dataclass(frozen=True, slots=True)
104131
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
105132
type: type[Node]
@@ -258,6 +285,11 @@ def __treescope_repr__(self, path, subtree_renderer):
258285
subtree_renderer=subtree_renderer,
259286
)
260287

288+
if config.flax_use_flaxlib:
289+
import flaxlib
290+
291+
jax.tree_util.register_static(flaxlib.NodeRef)
292+
globals()['NodeRef'] = flaxlib.NodeRef
261293

262294
@jax.tree_util.register_static
263295
@dataclasses.dataclass(frozen=True, repr=False)
@@ -299,6 +331,11 @@ def __treescope_repr__(self, path, subtree_renderer):
299331
subtree_renderer=subtree_renderer,
300332
)
301333

334+
if config.flax_use_flaxlib:
335+
import flaxlib
336+
337+
jax.tree_util.register_static(flaxlib.VariableDef)
338+
globals()['VariableDef'] = flaxlib.VariableDef
302339

303340
@jax.tree_util.register_static
304341
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
@@ -331,9 +368,6 @@ def with_same_outer_index(self) -> NodeDef[Node]:
331368
metadata=self.metadata,
332369
)
333370

334-
def replace(self, **kwargs):
335-
return dataclasses.replace(self, **kwargs)
336-
337371
def __nnx_repr__(self):
338372
yield reprlib.Object(type=type(self))
339373

@@ -358,6 +392,13 @@ def __treescope_repr__(self, path, subtree_renderer):
358392
)
359393

360394

395+
if config.flax_use_flaxlib:
396+
import flaxlib
397+
398+
jax.tree_util.register_static(flaxlib.NodeDef)
399+
globals()['NodeDef'] = flaxlib.NodeDef
400+
401+
361402
@jax.tree_util.register_static
362403
@dataclasses.dataclass(frozen=True, slots=True)
363404
class ArrayAttr:
@@ -548,23 +589,24 @@ def _graph_flatten(
548589
node: Node,
549590
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
550591
path: list[Key] | None,
551-
ref_index: RefMap[tp.Any, int],
592+
ref_index: RefMap,
552593
ref_outer_index: RefMap | None,
553594
nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]],
554595
attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]],
555596
leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray],
556597
paths: list[PathParts] | None,
557598
return_variables: bool,
558599
) -> None:
559-
is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl)
560-
is_graph_node_ = isinstance(node_impl, GraphNodeImpl)
561-
is_variable = isinstance(node, Variable)
600+
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
562601

563602
index: int | None
564603
if not is_pytree_node_ and node in ref_index:
565604
nodes.append(NodeRef(index := ref_index[node]))
566605
return
567606

607+
is_graph_node_ = type(node_impl) is GraphNodeImpl
608+
is_variable = isinstance(node, Variable)
609+
568610
# only cache graph nodes
569611
if is_graph_node_ or is_variable:
570612
index = len(ref_index)
@@ -600,13 +642,13 @@ def _graph_flatten(
600642
values, metadata = node_impl.flatten(node)
601643
num_attributes = len(values)
602644
nodedef = NodeDef(
603-
type=node_impl.type,
604-
index=index,
605-
outer_index=ref_outer_index[node]
645+
node_impl.type,
646+
index,
647+
ref_outer_index[node]
606648
if is_graph_node_ and ref_outer_index and node in ref_outer_index
607649
else None,
608-
num_attributes=num_attributes,
609-
metadata=metadata,
650+
num_attributes,
651+
metadata,
610652
)
611653
nodes.append(nodedef)
612654

@@ -866,8 +908,8 @@ def unflatten(
866908
state: State[Key, tp.Any] | FlatState[tp.Any] | list[tp.Any],
867909
/,
868910
*,
869-
index_ref: dict[Index, tp.Any] | None = None,
870-
outer_index_outer_ref: dict[Index, tp.Any] | None = None,
911+
index_ref: IndexMap | None = None,
912+
outer_index_outer_ref: IndexMap | None = None,
871913
) -> Node:
872914
"""Unflattens a graphdef into a node with the given state.
873915
@@ -893,7 +935,7 @@ def unflatten(
893935
else:
894936
raise ValueError(f'Unsupported state type: {type(state)}')
895937
if index_ref is None:
896-
index_ref = {}
938+
index_ref = IndexMap()
897939

898940
if len(leaves) != graphdef.num_leaves:
899941
raise ValueError(
@@ -937,8 +979,8 @@ def _graph_unflatten(
937979
tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]
938980
],
939981
leaves_iter: tp.Iterator[tp.Any],
940-
index_ref: dict[Index, tp.Any],
941-
outer_index_outer_ref: dict[Index, tp.Any] | None,
982+
index_ref: IndexMap,
983+
outer_index_outer_ref: IndexMap | None,
942984
) -> Node:
943985
"""Recursive helper for graph_unflatten.
944986
@@ -1002,7 +1044,7 @@ def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
10021044
assert type(nodedef) is NodeDef
10031045
if node_impl is None:
10041046
raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')
1005-
if nodedef.index in index_ref:
1047+
if nodedef.index is not None and nodedef.index in index_ref:
10061048
raise RuntimeError(f'GraphDef index {nodedef.index} already used.')
10071049

10081050
def _get_children() -> list[tuple[Key, tp.Any]]:
@@ -1215,7 +1257,7 @@ class StaticCache(tp.NamedTuple):
12151257
paths: tuple[PathParts, ...]
12161258
variables: list[Variable[tp.Any]]
12171259
new_ref_index: RefMap
1218-
new_index_ref: dict[Index, tp.Any]
1260+
new_index_ref: IndexMap
12191261

12201262
@staticmethod
12211263
def create(
@@ -1224,7 +1266,7 @@ def create(
12241266
variables: list[Variable[tp.Any]],
12251267
new_ref_index: RefMap,
12261268
):
1227-
new_index_ref = {index: obj for obj, index in new_ref_index.items()}
1269+
new_index_ref = IndexMap.from_refmap(new_ref_index)
12281270
final_graphdef: GraphDef[tp.Any]
12291271
final_graphdef = graphdef.with_same_outer_index()
12301272
return StaticCache(
@@ -1244,15 +1286,15 @@ class GraphContext(threading.local):
12441286
)
12451287
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
12461288
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
1247-
tmp_static_cache: RefMap[tp.Any, StaticCache] | None = None
1289+
tmp_static_cache: RefMap | None = None
12481290
caching: bool = False
12491291

12501292

12511293
GRAPH_CONTEXT = GraphContext()
12521294

12531295

12541296
@contextlib.contextmanager
1255-
def static_cache(static_cache: RefMap[tp.Any, StaticCache]):
1297+
def static_cache(static_cache: RefMap):
12561298
if GRAPH_CONTEXT.caching:
12571299
yield
12581300
return
@@ -1315,9 +1357,9 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
13151357
Returns:
13161358
A partial function expecting the remaining arguments to the original function.
13171359
"""
1318-
cache: RefMap[tp.Any, StaticCache] = RefMap()
1360+
cache: RefMap = RefMap()
13191361
original_ref_index: RefMap = RefMap()
1320-
index_ref: dict[Index, tp.Any] = {}
1362+
index_ref: IndexMap = IndexMap()
13211363
cached_ref_index: RefMap = RefMap()
13221364

13231365
def create_static_cache(x):
@@ -1543,7 +1585,7 @@ def split_context(ctxtag: tp.Hashable | None = None):
15431585
@dataclasses.dataclass
15441586
class MergeContext:
15451587
ctxtag: tp.Hashable | None
1546-
index_ref: dict[Index, tp.Any]
1588+
index_ref: IndexMap
15471589
is_inner: bool | None
15481590

15491591
def merge(
@@ -1669,7 +1711,7 @@ def merge_context(): ...
16691711
def merge_context(ctxtag: tp.Hashable | None, inner: bool | None): ...
16701712
@contextlib.contextmanager
16711713
def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None = None):
1672-
GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, {}, inner))
1714+
GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, IndexMap(), inner))
16731715

16741716
try:
16751717
yield GRAPH_CONTEXT.index_ref_stack[-1]
@@ -1692,11 +1734,11 @@ class UpdateContext:
16921734

16931735
tag: tp.Hashable
16941736
outer_ref_outer_index: RefMap | None
1695-
outer_index_inner_ref: dict[Index, tp.Any] | None
1737+
outer_index_inner_ref: IndexMap | None
16961738
# reverse caches
1697-
outer_index_outer_ref: dict[Index, tp.Any] | None
1739+
outer_index_outer_ref: IndexMap | None
16981740
inner_ref_outer_index: RefMap | None
1699-
static_cache: RefMap[tp.Any, StaticCache] | None
1741+
static_cache: RefMap | None
17001742

17011743
# define hash and eq to make this an opaque object
17021744
def __hash__(self):
@@ -1717,13 +1759,11 @@ def flatten_end(self, ref_index: RefMap):
17171759
self.outer_index_inner_ref = None
17181760
self.inner_ref_outer_index = None
17191761

1720-
def unflatten_end(self, index_ref: dict[Index, tp.Any], inner_merge: bool):
1762+
def unflatten_end(self, index_ref: IndexMap, inner_merge: bool):
17211763
if inner_merge:
17221764
# inner merge (2)
17231765
self.outer_index_inner_ref = index_ref
1724-
self.inner_ref_outer_index = RefMap(
1725-
(obj, index) for index, obj in index_ref.items()
1726-
)
1766+
self.inner_ref_outer_index = RefMap.from_indexmap(index_ref)
17271767

17281768

17291769
@dataclasses.dataclass

flaxlib_src/src/flaxlib/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@
1313
# limitations under the License.
1414

1515
from .flaxlib_cpp import RefMap as RefMap
16-
from .flaxlib_cpp import _graph_fingerprint as _graph_fingerprint
16+
from .flaxlib_cpp import IndexMap as IndexMap
17+
from .flaxlib_cpp import NodeDef as NodeDef
18+
from .flaxlib_cpp import VariableDef as VariableDef
19+
from .flaxlib_cpp import NodeRef as NodeRef

0 commit comments

Comments
 (0)