20
20
import threading
21
21
import typing as tp
22
22
23
+ from flax import config
23
24
from flax .nnx import filterlib , reprlib , traversals , variablelib
24
25
from flax .nnx import statelib
25
26
from flax .nnx .proxy_caller import (
@@ -63,27 +64,47 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
63
64
def is_node_leaf (x : tp .Any ) -> tpe .TypeGuard [NodeLeaf ]:
64
65
return isinstance (x , Variable )
65
66
67
+ class IndexMap (dict [Index , tp .Any ]):
68
+ @staticmethod
69
+ def from_refmap (refmap : RefMap ) -> IndexMap :
70
+ indexmap = IndexMap ()
71
+ indexmap .update ((index , value ) for value , index in refmap .items ())
72
+ return indexmap
73
+
74
+ if config .flax_use_flaxlib :
75
+ import flaxlib
76
+
77
+ globals ()['IndexMap' ] = flaxlib .IndexMap
78
+
66
79
67
80
# RefMap = dict
68
- class RefMap (tp .MutableMapping [A , B ]):
81
+ class RefMap (tp .MutableMapping [tp . Any , int ]):
69
82
"""A mapping that hashes keys by their identity."""
70
83
71
84
def __init__ (
72
- self ,
73
- mapping : tp .Mapping [A , B ] | tp .Iterable [tuple [A , B ]] | None = None ,
74
- / ,
85
+ self ,
86
+ mapping : tp .Mapping [tp .Any , int ]
87
+ | tp .Iterable [tuple [tp .Any , int ]]
88
+ | None = None ,
89
+ / ,
75
90
):
76
- self ._mapping : dict [int , tuple [A , B ]] = dict ()
91
+ self ._mapping : dict [int , tuple [tp . Any , int ]] = dict ()
77
92
if mapping is not None :
78
93
self .update (mapping )
79
94
80
- def __getitem__ (self , key : A ) -> B :
95
+ @staticmethod
96
+ def from_indexmap (indexmap : IndexMap ) -> RefMap :
97
+ refmap = RefMap ()
98
+ refmap .update ((value , index ) for index , value in indexmap .items ())
99
+ return refmap
100
+
101
+ def __getitem__ (self , key : tp .Any ) -> int :
81
102
return self ._mapping [id (key )][1 ]
82
103
83
- def __setitem__ (self , key : A , value : B ):
104
+ def __setitem__ (self , key : tp . Any , value : int ):
84
105
self ._mapping [id (key )] = (key , value )
85
106
86
- def __delitem__ (self , key : A ):
107
+ def __delitem__ (self , key : tp . Any ):
87
108
del self ._mapping [id (key )]
88
109
89
110
def __len__ (self ) -> int :
@@ -92,14 +113,20 @@ def __len__(self) -> int:
92
113
def __contains__ (self , key : tp .Any ) -> bool :
93
114
return id (key ) in self ._mapping
94
115
95
- def __iter__ (self ) -> tp .Iterator [A ]:
116
+ def __iter__ (self ) -> tp .Iterator [tp . Any ]:
96
117
for key , _ in self ._mapping .values ():
97
118
yield key
98
119
99
- def items (self ) -> tp .ItemsView [A , B ]:
120
+ def items (self ) -> tp .ItemsView [tp . Any , int ]:
100
121
return self ._mapping .values () # type: ignore
101
122
102
123
124
+ if config .flax_use_flaxlib :
125
+ import flaxlib
126
+
127
+ globals ()['RefMap' ] = flaxlib .RefMap
128
+
129
+
103
130
@dataclasses .dataclass (frozen = True , slots = True )
104
131
class NodeImplBase (tp .Generic [Node , Leaf , AuxData ]):
105
132
type : type [Node ]
@@ -258,6 +285,11 @@ def __treescope_repr__(self, path, subtree_renderer):
258
285
subtree_renderer = subtree_renderer ,
259
286
)
260
287
288
+ if config .flax_use_flaxlib :
289
+ import flaxlib
290
+
291
+ jax .tree_util .register_static (flaxlib .NodeRef )
292
+ globals ()['NodeRef' ] = flaxlib .NodeRef
261
293
262
294
@jax .tree_util .register_static
263
295
@dataclasses .dataclass (frozen = True , repr = False )
@@ -299,6 +331,11 @@ def __treescope_repr__(self, path, subtree_renderer):
299
331
subtree_renderer = subtree_renderer ,
300
332
)
301
333
334
+ if config .flax_use_flaxlib :
335
+ import flaxlib
336
+
337
+ jax .tree_util .register_static (flaxlib .VariableDef )
338
+ globals ()['VariableDef' ] = flaxlib .VariableDef
302
339
303
340
@jax .tree_util .register_static
304
341
@dataclasses .dataclass (frozen = True , repr = False , slots = True )
@@ -331,9 +368,6 @@ def with_same_outer_index(self) -> NodeDef[Node]:
331
368
metadata = self .metadata ,
332
369
)
333
370
334
- def replace (self , ** kwargs ):
335
- return dataclasses .replace (self , ** kwargs )
336
-
337
371
def __nnx_repr__ (self ):
338
372
yield reprlib .Object (type = type (self ))
339
373
@@ -358,6 +392,13 @@ def __treescope_repr__(self, path, subtree_renderer):
358
392
)
359
393
360
394
395
+ if config .flax_use_flaxlib :
396
+ import flaxlib
397
+
398
+ jax .tree_util .register_static (flaxlib .NodeDef )
399
+ globals ()['NodeDef' ] = flaxlib .NodeDef
400
+
401
+
361
402
@jax .tree_util .register_static
362
403
@dataclasses .dataclass (frozen = True , slots = True )
363
404
class ArrayAttr :
@@ -548,22 +589,23 @@ def _graph_flatten(
548
589
node : Node ,
549
590
node_impl : NodeImpl [Node , Leaf , AuxData ] | None ,
550
591
path : list [Key ] | None ,
551
- ref_index : RefMap [ tp . Any , int ] ,
592
+ ref_index : RefMap ,
552
593
ref_outer_index : RefMap | None ,
553
594
nodes : list [NodeDef [tp .Any ] | VariableDef [tp .Any ] | NodeRef [tp .Any ]],
554
595
attributes : list [tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]],
555
596
leaves : list [StateLeaf | Variable [tp .Any ] | jax .Array | np .ndarray ],
556
597
paths : list [PathParts ] | None ,
557
598
return_variables : bool ,
558
599
) -> None :
559
- is_pytree_node_ = isinstance (node_impl , PytreeNodeImpl )
560
- is_graph_node_ = isinstance (node_impl , GraphNodeImpl )
561
- is_variable = isinstance (node , Variable )
600
+ is_pytree_node_ = type (node_impl ) is PytreeNodeImpl
562
601
563
602
if not is_pytree_node_ and node in ref_index :
564
603
nodes .append (NodeRef (index := ref_index [node ]))
565
604
return
566
605
606
+ is_graph_node_ = type (node_impl ) is GraphNodeImpl
607
+ is_variable = isinstance (node , Variable )
608
+
567
609
# only cache graph nodes
568
610
if is_graph_node_ or is_variable :
569
611
index = len (ref_index )
@@ -599,13 +641,13 @@ def _graph_flatten(
599
641
values , metadata = node_impl .flatten (node )
600
642
num_attributes = len (values )
601
643
nodedef = NodeDef (
602
- type = node_impl .type ,
603
- index = index ,
604
- outer_index = ref_outer_index [node ]
644
+ node_impl .type ,
645
+ index ,
646
+ ref_outer_index [node ]
605
647
if is_graph_node_ and ref_outer_index and node in ref_outer_index
606
648
else None ,
607
- num_attributes = num_attributes ,
608
- metadata = metadata ,
649
+ num_attributes ,
650
+ metadata ,
609
651
)
610
652
nodes .append (nodedef )
611
653
@@ -865,8 +907,8 @@ def unflatten(
865
907
state : State [Key , tp .Any ] | FlatState [tp .Any ] | list [tp .Any ],
866
908
/ ,
867
909
* ,
868
- index_ref : dict [ Index , tp . Any ] | None = None ,
869
- outer_index_outer_ref : dict [ Index , tp . Any ] | None = None ,
910
+ index_ref : IndexMap | None = None ,
911
+ outer_index_outer_ref : IndexMap | None = None ,
870
912
) -> Node :
871
913
"""Unflattens a graphdef into a node with the given state.
872
914
@@ -892,7 +934,7 @@ def unflatten(
892
934
else :
893
935
raise ValueError (f'Unsupported state type: { type (state )} ' )
894
936
if index_ref is None :
895
- index_ref = {}
937
+ index_ref = IndexMap ()
896
938
897
939
if len (leaves ) != graphdef .num_leaves :
898
940
raise ValueError (
@@ -936,8 +978,8 @@ def _graph_unflatten(
936
978
tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]
937
979
],
938
980
leaves_iter : tp .Iterator [tp .Any ],
939
- index_ref : dict [ Index , tp . Any ] ,
940
- outer_index_outer_ref : dict [ Index , tp . Any ] | None ,
981
+ index_ref : IndexMap ,
982
+ outer_index_outer_ref : IndexMap | None ,
941
983
) -> Node :
942
984
"""Recursive helper for graph_unflatten.
943
985
@@ -1001,7 +1043,7 @@ def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
1001
1043
assert type (nodedef ) is NodeDef
1002
1044
if node_impl is None :
1003
1045
raise RuntimeError (f'Unsupported type: { nodedef .type } , this is a bug.' )
1004
- if nodedef .index in index_ref :
1046
+ if nodedef .index is not None and nodedef . index in index_ref :
1005
1047
raise RuntimeError (f'GraphDef index { nodedef .index } already used.' )
1006
1048
1007
1049
def _get_children () -> list [tuple [Key , tp .Any ]]:
@@ -1214,7 +1256,7 @@ class StaticCache(tp.NamedTuple):
1214
1256
paths : tuple [PathParts , ...]
1215
1257
variables : list [Variable [tp .Any ]]
1216
1258
new_ref_index : RefMap
1217
- new_index_ref : dict [ Index , tp . Any ]
1259
+ new_index_ref : IndexMap
1218
1260
1219
1261
@staticmethod
1220
1262
def create (
@@ -1223,7 +1265,7 @@ def create(
1223
1265
variables : list [Variable [tp .Any ]],
1224
1266
new_ref_index : RefMap ,
1225
1267
):
1226
- new_index_ref = { index : obj for obj , index in new_ref_index . items ()}
1268
+ new_index_ref = IndexMap . from_refmap ( new_ref_index )
1227
1269
final_graphdef : GraphDef [tp .Any ]
1228
1270
final_graphdef = graphdef .with_same_outer_index ()
1229
1271
return StaticCache (
@@ -1243,15 +1285,15 @@ class GraphContext(threading.local):
1243
1285
)
1244
1286
ref_index_stack : list [SplitContext ] = dataclasses .field (default_factory = list )
1245
1287
index_ref_stack : list [MergeContext ] = dataclasses .field (default_factory = list )
1246
- tmp_static_cache : RefMap [ tp . Any , StaticCache ] | None = None
1288
+ tmp_static_cache : RefMap | None = None
1247
1289
caching : bool = False
1248
1290
1249
1291
1250
1292
GRAPH_CONTEXT = GraphContext ()
1251
1293
1252
1294
1253
1295
@contextlib .contextmanager
1254
- def static_cache (static_cache : RefMap [ tp . Any , StaticCache ] ):
1296
+ def static_cache (static_cache : RefMap ):
1255
1297
if GRAPH_CONTEXT .caching :
1256
1298
yield
1257
1299
return
@@ -1314,9 +1356,9 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
1314
1356
Returns:
1315
1357
A partial function expecting the remaining arguments to the original function.
1316
1358
"""
1317
- cache : RefMap [ tp . Any , StaticCache ] = RefMap ()
1359
+ cache : RefMap = RefMap ()
1318
1360
original_ref_index : RefMap = RefMap ()
1319
- index_ref : dict [ Index , tp . Any ] = {}
1361
+ index_ref : IndexMap = IndexMap ()
1320
1362
cached_ref_index : RefMap = RefMap ()
1321
1363
1322
1364
def create_static_cache (x ):
@@ -1542,7 +1584,7 @@ def split_context(ctxtag: tp.Hashable | None = None):
1542
1584
@dataclasses .dataclass
1543
1585
class MergeContext :
1544
1586
ctxtag : tp .Hashable | None
1545
- index_ref : dict [ Index , tp . Any ]
1587
+ index_ref : IndexMap
1546
1588
is_inner : bool | None
1547
1589
1548
1590
def merge (
@@ -1668,7 +1710,7 @@ def merge_context(): ...
1668
1710
def merge_context (ctxtag : tp .Hashable | None , inner : bool | None ): ...
1669
1711
@contextlib .contextmanager
1670
1712
def merge_context (ctxtag : tp .Hashable | None = None , inner : bool | None = None ):
1671
- GRAPH_CONTEXT .index_ref_stack .append (MergeContext (ctxtag , {} , inner ))
1713
+ GRAPH_CONTEXT .index_ref_stack .append (MergeContext (ctxtag , IndexMap () , inner ))
1672
1714
1673
1715
try :
1674
1716
yield GRAPH_CONTEXT .index_ref_stack [- 1 ]
@@ -1691,11 +1733,11 @@ class UpdateContext:
1691
1733
1692
1734
tag : tp .Hashable
1693
1735
outer_ref_outer_index : RefMap | None
1694
- outer_index_inner_ref : dict [ Index , tp . Any ] | None
1736
+ outer_index_inner_ref : IndexMap | None
1695
1737
# reverse caches
1696
- outer_index_outer_ref : dict [ Index , tp . Any ] | None
1738
+ outer_index_outer_ref : IndexMap | None
1697
1739
inner_ref_outer_index : RefMap | None
1698
- static_cache : RefMap [ tp . Any , StaticCache ] | None
1740
+ static_cache : RefMap | None
1699
1741
1700
1742
# define hash and eq to make this an opaque object
1701
1743
def __hash__ (self ):
@@ -1716,13 +1758,11 @@ def flatten_end(self, ref_index: RefMap):
1716
1758
self .outer_index_inner_ref = None
1717
1759
self .inner_ref_outer_index = None
1718
1760
1719
- def unflatten_end (self , index_ref : dict [ Index , tp . Any ] , inner_merge : bool ):
1761
+ def unflatten_end (self , index_ref : IndexMap , inner_merge : bool ):
1720
1762
if inner_merge :
1721
1763
# inner merge (2)
1722
1764
self .outer_index_inner_ref = index_ref
1723
- self .inner_ref_outer_index = RefMap (
1724
- (obj , index ) for index , obj in index_ref .items ()
1725
- )
1765
+ self .inner_ref_outer_index = RefMap .from_indexmap (index_ref )
1726
1766
1727
1767
1728
1768
@dataclasses .dataclass
0 commit comments