20
20
import threading
21
21
import typing as tp
22
22
23
+ from flax import config
23
24
from flax .nnx import filterlib , reprlib , traversals , variablelib
24
25
from flax .nnx import statelib
25
26
from flax .nnx .proxy_caller import (
@@ -63,27 +64,47 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
63
64
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
  """Report whether ``x`` is a ``Variable`` instance.

  Args:
    x: Any object.

  Returns:
    ``True`` when ``x`` is an instance of ``Variable``, ``False`` otherwise.
  """
  is_variable_instance = isinstance(x, Variable)
  return is_variable_instance
65
66
67
class IndexMap(dict[Index, tp.Any]):
  """A ``dict`` subclass keyed by ``Index`` values."""

  @staticmethod
  def from_refmap(refmap: RefMap) -> IndexMap:
    """Build an ``IndexMap`` by inverting the pairs of ``refmap``.

    Args:
      refmap: Object whose ``items()`` yields ``(value, index)`` pairs.

    Returns:
      A new ``IndexMap`` in which every ``index`` maps to its ``value``.
    """
    inverted = IndexMap()
    for obj, idx in refmap.items():
      inverted[idx] = obj
    return inverted
73
+
74
# When the ``flax_use_flaxlib`` config flag is set, rebind this module's
# ``IndexMap`` name to ``flaxlib.IndexMap``. The rebinding goes through
# ``globals()`` rather than a plain assignment (presumably to avoid
# redefinition complaints from static type checkers -- TODO confirm).
if config.flax_use_flaxlib:
  import flaxlib

  globals().update(IndexMap=flaxlib.IndexMap)
78
+
66
79
67
80
# RefMap = dict
68
- class RefMap (tp .MutableMapping [A , B ]):
81
+ class RefMap (tp .MutableMapping [tp . Any , int ]):
69
82
"""A mapping that hashes keys by their identity."""
70
83
71
84
def __init__ (
72
- self ,
73
- mapping : tp .Mapping [A , B ] | tp .Iterable [tuple [A , B ]] | None = None ,
74
- / ,
85
+ self ,
86
+ mapping : tp .Mapping [tp .Any , int ]
87
+ | tp .Iterable [tuple [tp .Any , int ]]
88
+ | None = None ,
89
+ / ,
75
90
):
76
- self ._mapping : dict [int , tuple [A , B ]] = dict ()
91
+ self ._mapping : dict [int , tuple [tp . Any , int ]] = dict ()
77
92
if mapping is not None :
78
93
self .update (mapping )
79
94
80
- def __getitem__ (self , key : A ) -> B :
95
+ @staticmethod
96
+ def from_indexmap (indexmap : IndexMap ) -> RefMap :
97
+ refmap = RefMap ()
98
+ refmap .update ((value , index ) for index , value in indexmap .items ())
99
+ return refmap
100
+
101
+ def __getitem__ (self , key : tp .Any ) -> int :
81
102
return self ._mapping [id (key )][1 ]
82
103
83
- def __setitem__ (self , key : A , value : B ):
104
+ def __setitem__ (self , key : tp . Any , value : int ):
84
105
self ._mapping [id (key )] = (key , value )
85
106
86
- def __delitem__ (self , key : A ):
107
+ def __delitem__ (self , key : tp . Any ):
87
108
del self ._mapping [id (key )]
88
109
89
110
def __len__ (self ) -> int :
@@ -92,14 +113,20 @@ def __len__(self) -> int:
92
113
def __contains__ (self , key : tp .Any ) -> bool :
93
114
return id (key ) in self ._mapping
94
115
95
- def __iter__ (self ) -> tp .Iterator [A ]:
116
+ def __iter__ (self ) -> tp .Iterator [tp . Any ]:
96
117
for key , _ in self ._mapping .values ():
97
118
yield key
98
119
99
- def items (self ) -> tp .ItemsView [A , B ]:
120
+ def items (self ) -> tp .ItemsView [tp . Any , int ]:
100
121
return self ._mapping .values () # type: ignore
101
122
102
123
124
# When the ``flax_use_flaxlib`` config flag is set, rebind this module's
# ``RefMap`` name to ``flaxlib.RefMap``. The rebinding goes through
# ``globals()`` rather than a plain assignment (presumably to avoid
# redefinition complaints from static type checkers -- TODO confirm).
if config.flax_use_flaxlib:
  import flaxlib

  globals().update(RefMap=flaxlib.RefMap)
128
+
129
+
103
130
@dataclasses .dataclass (frozen = True , slots = True )
104
131
class NodeImplBase (tp .Generic [Node , Leaf , AuxData ]):
105
132
type : type [Node ]
@@ -258,6 +285,11 @@ def __treescope_repr__(self, path, subtree_renderer):
258
285
subtree_renderer = subtree_renderer ,
259
286
)
260
287
288
# When the ``flax_use_flaxlib`` config flag is set, register
# ``flaxlib.NodeRef`` with ``jax.tree_util.register_static`` and rebind this
# module's ``NodeRef`` name to it.
if config.flax_use_flaxlib:
  import flaxlib

  jax.tree_util.register_static(flaxlib.NodeRef)
  globals().update(NodeRef=flaxlib.NodeRef)
261
293
262
294
@jax .tree_util .register_static
263
295
@dataclasses .dataclass (frozen = True , repr = False )
@@ -299,6 +331,11 @@ def __treescope_repr__(self, path, subtree_renderer):
299
331
subtree_renderer = subtree_renderer ,
300
332
)
301
333
334
# When the ``flax_use_flaxlib`` config flag is set, register
# ``flaxlib.VariableDef`` with ``jax.tree_util.register_static`` and rebind
# this module's ``VariableDef`` name to it.
if config.flax_use_flaxlib:
  import flaxlib

  jax.tree_util.register_static(flaxlib.VariableDef)
  globals().update(VariableDef=flaxlib.VariableDef)
302
339
303
340
@jax .tree_util .register_static
304
341
@dataclasses .dataclass (frozen = True , repr = False , slots = True )
@@ -331,9 +368,6 @@ def with_same_outer_index(self) -> NodeDef[Node]:
331
368
metadata = self .metadata ,
332
369
)
333
370
334
- def replace (self , ** kwargs ):
335
- return dataclasses .replace (self , ** kwargs )
336
-
337
371
def __nnx_repr__ (self ):
338
372
yield reprlib .Object (type = type (self ))
339
373
@@ -358,6 +392,13 @@ def __treescope_repr__(self, path, subtree_renderer):
358
392
)
359
393
360
394
395
# When the ``flax_use_flaxlib`` config flag is set, register
# ``flaxlib.NodeDef`` with ``jax.tree_util.register_static`` and rebind this
# module's ``NodeDef`` name to it.
if config.flax_use_flaxlib:
  import flaxlib

  jax.tree_util.register_static(flaxlib.NodeDef)
  globals().update(NodeDef=flaxlib.NodeDef)
400
+
401
+
361
402
@jax .tree_util .register_static
362
403
@dataclasses .dataclass (frozen = True , slots = True )
363
404
class ArrayAttr :
@@ -548,23 +589,24 @@ def _graph_flatten(
548
589
node : Node ,
549
590
node_impl : NodeImpl [Node , Leaf , AuxData ] | None ,
550
591
path : list [Key ] | None ,
551
- ref_index : RefMap [ tp . Any , int ] ,
592
+ ref_index : RefMap ,
552
593
ref_outer_index : RefMap | None ,
553
594
nodes : list [NodeDef [tp .Any ] | VariableDef [tp .Any ] | NodeRef [tp .Any ]],
554
595
attributes : list [tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]],
555
596
leaves : list [StateLeaf | Variable [tp .Any ] | jax .Array | np .ndarray ],
556
597
paths : list [PathParts ] | None ,
557
598
return_variables : bool ,
558
599
) -> None :
559
- is_pytree_node_ = isinstance (node_impl , PytreeNodeImpl )
560
- is_graph_node_ = isinstance (node_impl , GraphNodeImpl )
561
- is_variable = isinstance (node , Variable )
600
+ is_pytree_node_ = type (node_impl ) is PytreeNodeImpl
562
601
563
602
index : int | None
564
603
if not is_pytree_node_ and node in ref_index :
565
604
nodes .append (NodeRef (index := ref_index [node ]))
566
605
return
567
606
607
+ is_graph_node_ = type (node_impl ) is GraphNodeImpl
608
+ is_variable = isinstance (node , Variable )
609
+
568
610
# only cache graph nodes
569
611
if is_graph_node_ or is_variable :
570
612
index = len (ref_index )
@@ -600,13 +642,13 @@ def _graph_flatten(
600
642
values , metadata = node_impl .flatten (node )
601
643
num_attributes = len (values )
602
644
nodedef = NodeDef (
603
- type = node_impl .type ,
604
- index = index ,
605
- outer_index = ref_outer_index [node ]
645
+ node_impl .type ,
646
+ index ,
647
+ ref_outer_index [node ]
606
648
if is_graph_node_ and ref_outer_index and node in ref_outer_index
607
649
else None ,
608
- num_attributes = num_attributes ,
609
- metadata = metadata ,
650
+ num_attributes ,
651
+ metadata ,
610
652
)
611
653
nodes .append (nodedef )
612
654
@@ -866,8 +908,8 @@ def unflatten(
866
908
state : State [Key , tp .Any ] | FlatState [tp .Any ] | list [tp .Any ],
867
909
/ ,
868
910
* ,
869
- index_ref : dict [ Index , tp . Any ] | None = None ,
870
- outer_index_outer_ref : dict [ Index , tp . Any ] | None = None ,
911
+ index_ref : IndexMap | None = None ,
912
+ outer_index_outer_ref : IndexMap | None = None ,
871
913
) -> Node :
872
914
"""Unflattens a graphdef into a node with the given state.
873
915
@@ -893,7 +935,7 @@ def unflatten(
893
935
else :
894
936
raise ValueError (f'Unsupported state type: { type (state )} ' )
895
937
if index_ref is None :
896
- index_ref = {}
938
+ index_ref = IndexMap ()
897
939
898
940
if len (leaves ) != graphdef .num_leaves :
899
941
raise ValueError (
@@ -937,8 +979,8 @@ def _graph_unflatten(
937
979
tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]
938
980
],
939
981
leaves_iter : tp .Iterator [tp .Any ],
940
- index_ref : dict [ Index , tp . Any ] ,
941
- outer_index_outer_ref : dict [ Index , tp . Any ] | None ,
982
+ index_ref : IndexMap ,
983
+ outer_index_outer_ref : IndexMap | None ,
942
984
) -> Node :
943
985
"""Recursive helper for graph_unflatten.
944
986
@@ -1002,7 +1044,7 @@ def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
1002
1044
assert type (nodedef ) is NodeDef
1003
1045
if node_impl is None :
1004
1046
raise RuntimeError (f'Unsupported type: { nodedef .type } , this is a bug.' )
1005
- if nodedef .index in index_ref :
1047
+ if nodedef .index is not None and nodedef . index in index_ref :
1006
1048
raise RuntimeError (f'GraphDef index { nodedef .index } already used.' )
1007
1049
1008
1050
def _get_children () -> list [tuple [Key , tp .Any ]]:
@@ -1215,7 +1257,7 @@ class StaticCache(tp.NamedTuple):
1215
1257
paths : tuple [PathParts , ...]
1216
1258
variables : list [Variable [tp .Any ]]
1217
1259
new_ref_index : RefMap
1218
- new_index_ref : dict [ Index , tp . Any ]
1260
+ new_index_ref : IndexMap
1219
1261
1220
1262
@staticmethod
1221
1263
def create (
@@ -1224,7 +1266,7 @@ def create(
1224
1266
variables : list [Variable [tp .Any ]],
1225
1267
new_ref_index : RefMap ,
1226
1268
):
1227
- new_index_ref = { index : obj for obj , index in new_ref_index . items ()}
1269
+ new_index_ref = IndexMap . from_refmap ( new_ref_index )
1228
1270
final_graphdef : GraphDef [tp .Any ]
1229
1271
final_graphdef = graphdef .with_same_outer_index ()
1230
1272
return StaticCache (
@@ -1244,15 +1286,15 @@ class GraphContext(threading.local):
1244
1286
)
1245
1287
ref_index_stack : list [SplitContext ] = dataclasses .field (default_factory = list )
1246
1288
index_ref_stack : list [MergeContext ] = dataclasses .field (default_factory = list )
1247
- tmp_static_cache : RefMap [ tp . Any , StaticCache ] | None = None
1289
+ tmp_static_cache : RefMap | None = None
1248
1290
caching : bool = False
1249
1291
1250
1292
1251
1293
GRAPH_CONTEXT = GraphContext ()
1252
1294
1253
1295
1254
1296
@contextlib .contextmanager
1255
- def static_cache (static_cache : RefMap [ tp . Any , StaticCache ] ):
1297
+ def static_cache (static_cache : RefMap ):
1256
1298
if GRAPH_CONTEXT .caching :
1257
1299
yield
1258
1300
return
@@ -1315,9 +1357,9 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
1315
1357
Returns:
1316
1358
A partial function expecting the remaining arguments to the original function.
1317
1359
"""
1318
- cache : RefMap [ tp . Any , StaticCache ] = RefMap ()
1360
+ cache : RefMap = RefMap ()
1319
1361
original_ref_index : RefMap = RefMap ()
1320
- index_ref : dict [ Index , tp . Any ] = {}
1362
+ index_ref : IndexMap = IndexMap ()
1321
1363
cached_ref_index : RefMap = RefMap ()
1322
1364
1323
1365
def create_static_cache (x ):
@@ -1543,7 +1585,7 @@ def split_context(ctxtag: tp.Hashable | None = None):
1543
1585
@dataclasses .dataclass
1544
1586
class MergeContext :
1545
1587
ctxtag : tp .Hashable | None
1546
- index_ref : dict [ Index , tp . Any ]
1588
+ index_ref : IndexMap
1547
1589
is_inner : bool | None
1548
1590
1549
1591
def merge (
@@ -1669,7 +1711,7 @@ def merge_context(): ...
1669
1711
def merge_context (ctxtag : tp .Hashable | None , inner : bool | None ): ...
1670
1712
@contextlib .contextmanager
1671
1713
def merge_context (ctxtag : tp .Hashable | None = None , inner : bool | None = None ):
1672
- GRAPH_CONTEXT .index_ref_stack .append (MergeContext (ctxtag , {} , inner ))
1714
+ GRAPH_CONTEXT .index_ref_stack .append (MergeContext (ctxtag , IndexMap () , inner ))
1673
1715
1674
1716
try :
1675
1717
yield GRAPH_CONTEXT .index_ref_stack [- 1 ]
@@ -1692,11 +1734,11 @@ class UpdateContext:
1692
1734
1693
1735
tag : tp .Hashable
1694
1736
outer_ref_outer_index : RefMap | None
1695
- outer_index_inner_ref : dict [ Index , tp . Any ] | None
1737
+ outer_index_inner_ref : IndexMap | None
1696
1738
# reverse caches
1697
- outer_index_outer_ref : dict [ Index , tp . Any ] | None
1739
+ outer_index_outer_ref : IndexMap | None
1698
1740
inner_ref_outer_index : RefMap | None
1699
- static_cache : RefMap [ tp . Any , StaticCache ] | None
1741
+ static_cache : RefMap | None
1700
1742
1701
1743
# define hash and eq to make this an opaque object
1702
1744
def __hash__ (self ):
@@ -1717,13 +1759,11 @@ def flatten_end(self, ref_index: RefMap):
1717
1759
self .outer_index_inner_ref = None
1718
1760
self .inner_ref_outer_index = None
1719
1761
1720
def unflatten_end(self, index_ref: IndexMap, inner_merge: bool):
  """Record the result of an unflatten step when it was an inner merge.

  Args:
    index_ref: Index-to-object map produced by the unflatten step.
    inner_merge: When false, this method does nothing.
  """
  if not inner_merge:
    return
  # inner merge (2)
  self.outer_index_inner_ref = index_ref
  # Keep the reverse (object -> index) view alongside the forward map.
  self.inner_ref_outer_index = RefMap.from_indexmap(index_ref)
1727
1767
1728
1768
1729
1769
@dataclasses .dataclass
0 commit comments