56
56
GraphState = State [Key , StateLeaf ]
57
57
GraphFlatState = FlatState [StateLeaf ]
58
58
59
+ def maybe_use_flaxlib (name : str , register_static : bool = False ):
60
+ if config .flax_use_flaxlib :
61
+ import flaxlib
62
+
63
+ flaxlib_version : type = getattr (flaxlib , name )
64
+ globals ()[name ] = flaxlib_version
65
+
66
+ if register_static :
67
+ jax .tree_util .register_static (flaxlib_version )
59
68
60
69
def is_state_leaf (x : tp .Any ) -> tpe .TypeGuard [StateLeaf ]:
61
70
return isinstance (x , VariableState )
@@ -71,10 +80,7 @@ def from_refmap(refmap: RefMap) -> IndexMap:
71
80
indexmap .update ((index , value ) for value , index in refmap .items ())
72
81
return indexmap
73
82
74
- if config .flax_use_flaxlib :
75
- import flaxlib
76
-
77
- globals ()['IndexMap' ] = flaxlib .IndexMap
83
+ maybe_use_flaxlib ('IndexMap' )
78
84
79
85
80
86
# RefMap = dict
@@ -120,11 +126,12 @@ def __iter__(self) -> tp.Iterator[tp.Any]:
120
126
def items (self ) -> tp .ItemsView [tp .Any , int ]:
121
127
return self ._mapping .values () # type: ignore
122
128
129
+ # TODO(cgarciae): PyRefMap is currently being used for
130
+ # cached_partial, but it should be removed when cached_partial
131
+ # is no longer needed.
132
+ PyRefMap = RefMap
123
133
124
- if config .flax_use_flaxlib :
125
- import flaxlib
126
-
127
- globals ()['RefMap' ] = flaxlib .RefMap
134
+ maybe_use_flaxlib ('RefMap' )
128
135
129
136
130
137
@dataclasses .dataclass (frozen = True , slots = True )
@@ -136,28 +143,33 @@ def node_dict(self, node: Node) -> dict[Key, Leaf]:
136
143
nodes , _ = self .flatten (node )
137
144
return dict (nodes )
138
145
146
+ maybe_use_flaxlib ('NodeImplBase' )
139
147
140
148
@dataclasses .dataclass (frozen = True , slots = True )
141
- class GraphNodeImpl (NodeImplBase [Node , Leaf , AuxData ]):
142
- set_key : tp .Callable [[Node , Key , Leaf ], None ]
143
- pop_key : tp .Callable [[Node , Key ], Leaf ]
144
- create_empty : tp .Callable [[AuxData ], Node ]
145
- clear : tp .Callable [[Node ], None ]
146
- init : tp .Callable [[Node , tp .Iterable [tuple [Key , Leaf ]]], None ]
149
+ class GraphNodeImpl (NodeImplBase ):
150
+ set_key : tp .Callable [[tp .Any , Key , tp .Any ], None ]
151
+ pop_key : tp .Callable [[tp .Any , Key ], tp .Any ]
152
+ create_empty : tp .Callable [[tp .Any ], tp .Any ]
153
+ clear : tp .Callable [[tp .Any ], None ]
154
+ init : tp .Callable [[tp .Any , tp .Iterable [tuple [Key , tp .Any ]]], None ]
155
+
156
+
157
+ maybe_use_flaxlib ('GraphNodeImpl' )
147
158
148
159
149
160
@dataclasses .dataclass (frozen = True , slots = True )
150
- class PytreeNodeImpl (NodeImplBase [Node , Leaf , AuxData ]):
151
- unflatten : tp .Callable [[tp .Sequence [tuple [Key , Leaf ]], AuxData ], Node ]
161
+ class PytreeNodeImpl (NodeImplBase ):
162
+ unflatten : tp .Callable [[tp .Sequence [tuple [Key , tp .Any ]], tp .Any ], tp .Any ]
163
+
164
+
165
+ maybe_use_flaxlib ('PytreeNodeImpl' )
152
166
153
167
154
- NodeImpl = tp .Union [
155
- GraphNodeImpl [Node , Leaf , AuxData ], PytreeNodeImpl [Node , Leaf , AuxData ]
156
- ]
168
+ NodeImpl = tp .Union [GraphNodeImpl , PytreeNodeImpl ]
157
169
158
170
159
- GRAPH_REGISTRY : dict [type , NodeImpl [ tp . Any , tp . Any , tp . Any ] ] = {}
160
- PYTREE_REGISTRY : dict [type , PytreeNodeImpl [ tp . Any , tp . Any , tp . Any ] ] = {}
171
+ GRAPH_REGISTRY : dict [type , NodeImpl ] = {}
172
+ PYTREE_REGISTRY : dict [type , PytreeNodeImpl ] = {}
161
173
162
174
163
175
def register_graph_node_type (
@@ -173,13 +185,13 @@ def register_graph_node_type(
173
185
raise ValueError (f'Node type { type } is already registered.' )
174
186
175
187
GRAPH_REGISTRY [type ] = GraphNodeImpl (
176
- type = type ,
177
- flatten = flatten ,
178
- set_key = set_key ,
179
- pop_key = pop_key ,
180
- create_empty = create_empty ,
181
- clear = clear ,
182
- init = init ,
188
+ type ,
189
+ flatten ,
190
+ set_key ,
191
+ pop_key ,
192
+ create_empty ,
193
+ clear ,
194
+ init ,
183
195
)
184
196
185
197
@@ -191,9 +203,7 @@ def register_pytree_node_type(
191
203
if type in PYTREE_REGISTRY :
192
204
raise ValueError (f'Node type { type } is already registered.' )
193
205
194
- PYTREE_REGISTRY [type ] = PytreeNodeImpl (
195
- type = type , flatten = flatten , unflatten = unflatten
196
- )
206
+ PYTREE_REGISTRY [type ] = PytreeNodeImpl (type , flatten , unflatten )
197
207
198
208
199
209
def is_node (x : tp .Any ) -> bool :
@@ -210,16 +220,13 @@ def is_node_type(x: type[tp.Any]) -> bool:
210
220
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
211
221
212
222
213
- def get_node_impl (x : Node ) -> NodeImpl [Node , tp .Any , tp .Any ] | None :
214
- if isinstance (x , Variable ):
215
- return None
216
-
223
+ def get_node_impl (x ) -> NodeImpl | None :
217
224
node_type = type (x )
218
225
219
- if node_type in GRAPH_REGISTRY :
220
- return GRAPH_REGISTRY [ node_type ]
221
- elif node_type in PYTREE_REGISTRY :
222
- return PYTREE_REGISTRY [ node_type ]
226
+ if node_impl := GRAPH_REGISTRY . get ( node_type ) :
227
+ return node_impl
228
+ elif node_impl := PYTREE_REGISTRY . get ( node_type ) :
229
+ return node_impl
223
230
elif node_type in JAX_PYTREE_REGISTRY or issubclass (node_type , tuple ):
224
231
return PYTREE_NODE_IMPL # type: ignore
225
232
else :
@@ -285,11 +292,7 @@ def __treescope_repr__(self, path, subtree_renderer):
285
292
subtree_renderer = subtree_renderer ,
286
293
)
287
294
288
- if config .flax_use_flaxlib :
289
- import flaxlib
290
-
291
- jax .tree_util .register_static (flaxlib .NodeRef )
292
- globals ()['NodeRef' ] = flaxlib .NodeRef
295
+ maybe_use_flaxlib ('NodeRef' , register_static = True )
293
296
294
297
@jax .tree_util .register_static
295
298
@dataclasses .dataclass (frozen = True , repr = False )
@@ -331,11 +334,7 @@ def __treescope_repr__(self, path, subtree_renderer):
331
334
subtree_renderer = subtree_renderer ,
332
335
)
333
336
334
- if config .flax_use_flaxlib :
335
- import flaxlib
336
-
337
- jax .tree_util .register_static (flaxlib .VariableDef )
338
- globals ()['VariableDef' ] = flaxlib .VariableDef
337
+ maybe_use_flaxlib ('VariableDef' , register_static = True )
339
338
340
339
@jax .tree_util .register_static
341
340
@dataclasses .dataclass (frozen = True , repr = False , slots = True )
@@ -391,12 +390,7 @@ def __treescope_repr__(self, path, subtree_renderer):
391
390
subtree_renderer = subtree_renderer ,
392
391
)
393
392
394
-
395
- if config .flax_use_flaxlib :
396
- import flaxlib
397
-
398
- jax .tree_util .register_static (flaxlib .NodeDef )
399
- globals ()['NodeDef' ] = flaxlib .NodeDef
393
+ maybe_use_flaxlib ('NodeDef' , register_static = True )
400
394
401
395
402
396
@jax .tree_util .register_static
@@ -628,10 +622,10 @@ def _graph_flatten(
628
622
assert paths is not None
629
623
paths .append (tuple (path ))
630
624
variabledef = VariableDef (
631
- type = type (node ),
632
- index = index ,
633
- outer_index = ref_outer_index .get (node , None ) if ref_outer_index else None ,
634
- metadata = HashableMapping (node ._var_metadata ),
625
+ type (node ),
626
+ index ,
627
+ ref_outer_index .get (node , None ) if ref_outer_index else None ,
628
+ HashableMapping (node ._var_metadata ),
635
629
)
636
630
nodes .append (variabledef )
637
631
return
@@ -683,6 +677,111 @@ def _graph_flatten(
683
677
684
678
return
685
679
680
+ def _flatten_fast (
681
+ node : Node ,
682
+ / ,
683
+ * ,
684
+ ref_index : RefMap ,
685
+ ref_outer_index : RefMap | None ,
686
+ ) -> tuple [GraphDef [Node ], list [tp .Any ]]:
687
+ leaves : list [jax .Array | np .ndarray ] = []
688
+ nodes : list [NodeDef [tp .Any ] | VariableDef [tp .Any ] | NodeRef [tp .Any ]] = []
689
+ attributes : list [tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]] = []
690
+ node_impl = get_node_impl (node )
691
+ if node_impl is None and not isinstance (node , Variable ):
692
+ raise RuntimeError (f'Unsupported type: { type (node )} , this is a bug.' )
693
+ _graph_flatten_fast (
694
+ node ,
695
+ node_impl ,
696
+ ref_index ,
697
+ ref_outer_index ,
698
+ nodes ,
699
+ attributes ,
700
+ leaves ,
701
+ )
702
+ graphdef = GraphDef (
703
+ nodes = nodes , attributes = attributes , num_leaves = len (leaves )
704
+ )
705
+
706
+ return graphdef , leaves
707
+
708
+
709
+ def _graph_flatten_fast (
710
+ node : Node ,
711
+ node_impl : NodeImpl [Node , Leaf , AuxData ] | None ,
712
+ ref_index : RefMap ,
713
+ ref_outer_index : RefMap | None ,
714
+ nodes : list [NodeDef [tp .Any ] | VariableDef [tp .Any ] | NodeRef [tp .Any ]],
715
+ attributes : list [tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]],
716
+ leaves : list [jax .Array | np .ndarray ],
717
+ ) -> None :
718
+ is_pytree_node_ = type (node_impl ) is PytreeNodeImpl
719
+
720
+ if not is_pytree_node_ and node in ref_index :
721
+ nodes .append (NodeRef (index := ref_index [node ]))
722
+ return
723
+
724
+ is_graph_node_ = type (node_impl ) is GraphNodeImpl
725
+ is_variable = isinstance (node , Variable )
726
+
727
+ # only cache graph nodes
728
+ if is_pytree_node_ :
729
+ index = None
730
+ else :
731
+ index = len (ref_index )
732
+ ref_index [node ] = index
733
+
734
+ if is_variable :
735
+ assert isinstance (node , Variable )
736
+ assert index is not None
737
+ leaf = node .raw_value
738
+ leaves .append (leaf )
739
+ variabledef = VariableDef (
740
+ type (node ),
741
+ index ,
742
+ ref_outer_index .get (node , None ) if ref_outer_index else None ,
743
+ HashableMapping (node ._var_metadata ),
744
+ )
745
+ nodes .append (variabledef )
746
+ return
747
+
748
+ if node_impl is None :
749
+ raise RuntimeError (f'Unsupported type: { type (node )} , this is a bug.' )
750
+
751
+ values , metadata = node_impl .flatten (node )
752
+ num_attributes = len (values )
753
+ nodedef = NodeDef (
754
+ node_impl .type ,
755
+ index ,
756
+ ref_outer_index [node ]
757
+ if is_graph_node_ and ref_outer_index and node in ref_outer_index
758
+ else None ,
759
+ num_attributes ,
760
+ metadata ,
761
+ )
762
+ nodes .append (nodedef )
763
+
764
+ for key , value in values :
765
+ value_node_impl = get_node_impl (value )
766
+ if value_node_impl is not None or isinstance (value , Variable ):
767
+ attributes .append ((key , NODE_ATTR ))
768
+ _graph_flatten_fast (
769
+ value ,
770
+ value_node_impl ,
771
+ ref_index ,
772
+ ref_outer_index ,
773
+ nodes ,
774
+ attributes ,
775
+ leaves ,
776
+ )
777
+ elif isinstance (value , (jax .Array , np .ndarray )):
778
+ attributes .append ((key , ARRAY_ATTR ))
779
+ leaves .append (value )
780
+ else :
781
+ attributes .append ((key , Static (value )))
782
+
783
+ return
784
+
686
785
687
786
@dataclasses .dataclass (slots = True )
688
787
class FingerprintContext :
@@ -1286,15 +1385,15 @@ class GraphContext(threading.local):
1286
1385
)
1287
1386
ref_index_stack : list [SplitContext ] = dataclasses .field (default_factory = list )
1288
1387
index_ref_stack : list [MergeContext ] = dataclasses .field (default_factory = list )
1289
- tmp_static_cache : RefMap | None = None
1388
+ tmp_static_cache : PyRefMap | None = None
1290
1389
caching : bool = False
1291
1390
1292
1391
1293
1392
GRAPH_CONTEXT = GraphContext ()
1294
1393
1295
1394
1296
1395
@contextlib .contextmanager
1297
- def static_cache (static_cache : RefMap ):
1396
+ def static_cache (static_cache : PyRefMap ):
1298
1397
if GRAPH_CONTEXT .caching :
1299
1398
yield
1300
1399
return
@@ -1357,7 +1456,7 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
1357
1456
Returns:
1358
1457
A partial function expecting the remaining arguments to the original function.
1359
1458
"""
1360
- cache : RefMap = RefMap ()
1459
+ cache = PyRefMap ()
1361
1460
original_ref_index : RefMap = RefMap ()
1362
1461
index_ref : IndexMap = IndexMap ()
1363
1462
cached_ref_index : RefMap = RefMap ()
@@ -1379,7 +1478,7 @@ def create_static_cache(x):
1379
1478
new_ref_index = cached_new_ref_index ,
1380
1479
)
1381
1480
cached_ref_index .update (cached_new_ref_index )
1382
- cache [node_cache ] = StaticCache .create (
1481
+ cache [node_cache ] = StaticCache .create ( # type: ignore
1383
1482
graphdef , paths , variables , cached_new_ref_index
1384
1483
)
1385
1484
return node_cache
@@ -1564,6 +1663,20 @@ def flatten(
1564
1663
else :
1565
1664
return graphdef , leaves
1566
1665
1666
+ def flatten_fast (self , node : A ) -> tuple [GraphDef [A ], list [tp .Any ]]:
1667
+ ctx = (
1668
+ current_update_context (self .ctxtag ) if self .ctxtag is not None else None
1669
+ )
1670
+ ref_outer_index = (
1671
+ ctx .inner_ref_outer_index if ctx and ctx .inner_ref_outer_index else None
1672
+ )
1673
+ graphdef , leaves = _flatten_fast (
1674
+ node ,
1675
+ ref_index = self .ref_index ,
1676
+ ref_outer_index = ref_outer_index ,
1677
+ )
1678
+ return graphdef , leaves
1679
+
1567
1680
1568
1681
@contextlib .contextmanager
1569
1682
def split_context (ctxtag : tp .Hashable | None = None ):
@@ -2621,9 +2734,9 @@ def _unflatten_pytree(
2621
2734
2622
2735
2623
2736
PYTREE_NODE_IMPL = PytreeNodeImpl (
2624
- type = GenericPytree ,
2625
- flatten = _flatten_pytree ,
2626
- unflatten = _unflatten_pytree , # type: ignore
2737
+ GenericPytree ,
2738
+ _flatten_pytree ,
2739
+ _unflatten_pytree , # type: ignore
2627
2740
)
2628
2741
2629
2742
# common pytrees
0 commit comments