56
56
GraphState = State [Key , StateLeaf ]
57
57
GraphFlatState = FlatState [StateLeaf ]
58
58
59
def maybe_use_flaxlib(name: str, register_static: bool = False):
  """Rebind the module-level ``name`` to the attribute of the same name in ``flaxlib``.

  Nothing happens unless ``config.flax_use_flaxlib`` is truthy.

  Args:
    name: attribute to fetch from ``flaxlib``; the same string is used as the
      key written into this module's ``globals()``.
    register_static: when true, the fetched object is also passed to
      ``jax.tree_util.register_static``.
  """
  if not config.flax_use_flaxlib:
    return

  # The import lives inside the function, so it only executes when the
  # flag above is set.
  import flaxlib

  replacement: type = getattr(flaxlib, name)
  globals()[name] = replacement

  if not register_static:
    return
  jax.tree_util.register_static(replacement)
59
68
60
69
def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
  """Return ``True`` when ``x`` is an instance of ``VariableState``."""
  return isinstance(x, VariableState)
@@ -71,10 +80,7 @@ def from_refmap(refmap: RefMap) -> IndexMap:
71
80
indexmap .update ((index , value ) for value , index in refmap .items ())
72
81
return indexmap
73
82
74
- if config .flax_use_flaxlib :
75
- import flaxlib
76
-
77
- globals ()['IndexMap' ] = flaxlib .IndexMap
83
+ maybe_use_flaxlib ('IndexMap' )
78
84
79
85
80
86
# RefMap = dict
@@ -120,11 +126,12 @@ def __iter__(self) -> tp.Iterator[tp.Any]:
120
126
def items (self ) -> tp .ItemsView [tp .Any , int ]:
121
127
return self ._mapping .values () # type: ignore
122
128
129
+ # TODO(cgarciae): PyRefMap is currently being used for
130
+ # cached_partial, but it should be removed when cached_partial
131
+ # is no longer needed.
132
+ PyRefMap = RefMap
123
133
124
- if config .flax_use_flaxlib :
125
- import flaxlib
126
-
127
- globals ()['RefMap' ] = flaxlib .RefMap
134
+ maybe_use_flaxlib ('RefMap' )
128
135
129
136
130
137
@dataclasses .dataclass (frozen = True , slots = True )
@@ -136,28 +143,33 @@ def node_dict(self, node: Node) -> dict[Key, Leaf]:
136
143
nodes , _ = self .flatten (node )
137
144
return dict (nodes )
138
145
146
+ maybe_use_flaxlib ('NodeImplBase' )
139
147
140
148
@dataclasses .dataclass (frozen = True , slots = True )
141
- class GraphNodeImpl (NodeImplBase [Node , Leaf , AuxData ]):
142
- set_key : tp .Callable [[Node , Key , Leaf ], None ]
143
- pop_key : tp .Callable [[Node , Key ], Leaf ]
144
- create_empty : tp .Callable [[AuxData ], Node ]
145
- clear : tp .Callable [[Node ], None ]
146
- init : tp .Callable [[Node , tp .Iterable [tuple [Key , Leaf ]]], None ]
149
+ class GraphNodeImpl (NodeImplBase ):
150
+ set_key : tp .Callable [[tp .Any , Key , tp .Any ], None ]
151
+ pop_key : tp .Callable [[tp .Any , Key ], tp .Any ]
152
+ create_empty : tp .Callable [[tp .Any ], tp .Any ]
153
+ clear : tp .Callable [[tp .Any ], None ]
154
+ init : tp .Callable [[tp .Any , tp .Iterable [tuple [Key , tp .Any ]]], None ]
155
+
156
+
157
+ maybe_use_flaxlib ('GraphNodeImpl' )
147
158
148
159
149
160
@dataclasses .dataclass (frozen = True , slots = True )
150
- class PytreeNodeImpl (NodeImplBase [Node , Leaf , AuxData ]):
151
- unflatten : tp .Callable [[tp .Sequence [tuple [Key , Leaf ]], AuxData ], Node ]
161
+ class PytreeNodeImpl (NodeImplBase ):
162
+ unflatten : tp .Callable [[tp .Sequence [tuple [Key , tp .Any ]], tp .Any ], tp .Any ]
163
+
164
+
165
+ maybe_use_flaxlib ('PytreeNodeImpl' )
152
166
153
167
154
- NodeImpl = tp .Union [
155
- GraphNodeImpl [Node , Leaf , AuxData ], PytreeNodeImpl [Node , Leaf , AuxData ]
156
- ]
168
+ NodeImpl = tp .Union [GraphNodeImpl , PytreeNodeImpl ]
157
169
158
170
159
- GRAPH_REGISTRY : dict [type , NodeImpl [ tp . Any , tp . Any , tp . Any ] ] = {}
160
- PYTREE_REGISTRY : dict [type , PytreeNodeImpl [ tp . Any , tp . Any , tp . Any ] ] = {}
171
+ GRAPH_REGISTRY : dict [type , NodeImpl ] = {}
172
+ PYTREE_REGISTRY : dict [type , PytreeNodeImpl ] = {}
161
173
162
174
163
175
def register_graph_node_type (
@@ -173,13 +185,13 @@ def register_graph_node_type(
173
185
raise ValueError (f'Node type { type } is already registered.' )
174
186
175
187
GRAPH_REGISTRY [type ] = GraphNodeImpl (
176
- type = type ,
177
- flatten = flatten ,
178
- set_key = set_key ,
179
- pop_key = pop_key ,
180
- create_empty = create_empty ,
181
- clear = clear ,
182
- init = init ,
188
+ type ,
189
+ flatten ,
190
+ set_key ,
191
+ pop_key ,
192
+ create_empty ,
193
+ clear ,
194
+ init ,
183
195
)
184
196
185
197
@@ -191,9 +203,7 @@ def register_pytree_node_type(
191
203
if type in PYTREE_REGISTRY :
192
204
raise ValueError (f'Node type { type } is already registered.' )
193
205
194
- PYTREE_REGISTRY [type ] = PytreeNodeImpl (
195
- type = type , flatten = flatten , unflatten = unflatten
196
- )
206
+ PYTREE_REGISTRY [type ] = PytreeNodeImpl (type , flatten , unflatten )
197
207
198
208
199
209
def is_node (x : tp .Any ) -> bool :
@@ -210,16 +220,13 @@ def is_node_type(x: type[tp.Any]) -> bool:
210
220
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
211
221
212
222
213
- def get_node_impl (x : Node ) -> NodeImpl [Node , tp .Any , tp .Any ] | None :
214
- if isinstance (x , Variable ):
215
- return None
216
-
223
+ def get_node_impl (x ) -> NodeImpl | None :
217
224
node_type = type (x )
218
225
219
- if node_type in GRAPH_REGISTRY :
220
- return GRAPH_REGISTRY [ node_type ]
221
- elif node_type in PYTREE_REGISTRY :
222
- return PYTREE_REGISTRY [ node_type ]
226
+ if node_impl := GRAPH_REGISTRY . get ( node_type ) :
227
+ return node_impl
228
+ elif node_impl := PYTREE_REGISTRY . get ( node_type ) :
229
+ return node_impl
223
230
elif node_type in JAX_PYTREE_REGISTRY or issubclass (node_type , tuple ):
224
231
return PYTREE_NODE_IMPL # type: ignore
225
232
else :
@@ -285,11 +292,7 @@ def __treescope_repr__(self, path, subtree_renderer):
285
292
subtree_renderer = subtree_renderer ,
286
293
)
287
294
288
- if config .flax_use_flaxlib :
289
- import flaxlib
290
-
291
- jax .tree_util .register_static (flaxlib .NodeRef )
292
- globals ()['NodeRef' ] = flaxlib .NodeRef
295
+ maybe_use_flaxlib ('NodeRef' , register_static = True )
293
296
294
297
@jax .tree_util .register_static
295
298
@dataclasses .dataclass (frozen = True , repr = False )
@@ -331,11 +334,7 @@ def __treescope_repr__(self, path, subtree_renderer):
331
334
subtree_renderer = subtree_renderer ,
332
335
)
333
336
334
- if config .flax_use_flaxlib :
335
- import flaxlib
336
-
337
- jax .tree_util .register_static (flaxlib .VariableDef )
338
- globals ()['VariableDef' ] = flaxlib .VariableDef
337
+ maybe_use_flaxlib ('VariableDef' , register_static = True )
339
338
340
339
@jax .tree_util .register_static
341
340
@dataclasses .dataclass (frozen = True , repr = False , slots = True )
@@ -391,12 +390,7 @@ def __treescope_repr__(self, path, subtree_renderer):
391
390
subtree_renderer = subtree_renderer ,
392
391
)
393
392
394
-
395
- if config .flax_use_flaxlib :
396
- import flaxlib
397
-
398
- jax .tree_util .register_static (flaxlib .NodeDef )
399
- globals ()['NodeDef' ] = flaxlib .NodeDef
393
+ maybe_use_flaxlib ('NodeDef' , register_static = True )
400
394
401
395
402
396
@jax .tree_util .register_static
@@ -627,10 +621,10 @@ def _graph_flatten(
627
621
assert paths is not None
628
622
paths .append (tuple (path ))
629
623
variabledef = VariableDef (
630
- type = type (node ),
631
- index = index ,
632
- outer_index = ref_outer_index .get (node , None ) if ref_outer_index else None ,
633
- metadata = HashableMapping (node ._var_metadata ),
624
+ type (node ),
625
+ index ,
626
+ ref_outer_index .get (node , None ) if ref_outer_index else None ,
627
+ HashableMapping (node ._var_metadata ),
634
628
)
635
629
nodes .append (variabledef )
636
630
return
@@ -682,6 +676,111 @@ def _graph_flatten(
682
676
683
677
return
684
678
679
def _flatten_fast(
  node: Node,
  /,
  *,
  ref_index: RefMap,
  ref_outer_index: RefMap | None,
) -> tuple[GraphDef[Node], list[tp.Any]]:
  """Flatten ``node`` into a ``GraphDef`` plus the list of collected leaves.

  Args:
    node: root object to flatten (positional-only).
    ref_index: mapping from already-visited objects to their index; it is
      forwarded to ``_graph_flatten_fast``, which adds entries to it.
    ref_outer_index: optional mapping forwarded to ``_graph_flatten_fast``
      for outer-index lookups, or ``None``.

  Returns:
    A ``(graphdef, leaves)`` tuple; ``graphdef.num_leaves`` is set to
    ``len(leaves)``.

  Raises:
    RuntimeError: if ``get_node_impl`` returns ``None`` for ``node`` and
      ``node`` is not a ``Variable``.
  """
  impl = get_node_impl(node)
  if impl is None and not isinstance(node, Variable):
    raise RuntimeError(f'Unsupported type: {type(node)} , this is a bug.')

  # Output buffers, filled in place by the recursive helper.
  flat_leaves: list[jax.Array | np.ndarray] = []
  node_defs: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]] = []
  attr_defs: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]] = []

  _graph_flatten_fast(
    node,
    impl,
    ref_index,
    ref_outer_index,
    node_defs,
    attr_defs,
    flat_leaves,
  )

  return (
    GraphDef(
      nodes=node_defs, attributes=attr_defs, num_leaves=len(flat_leaves)
    ),
    flat_leaves,
  )
706
+
707
+
708
def _graph_flatten_fast(
  node: Node,
  node_impl: NodeImpl | None,
  ref_index: RefMap,
  ref_outer_index: RefMap | None,
  nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]],
  attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]],
  leaves: list[jax.Array | np.ndarray],
) -> None:
  """Recursively flatten ``node``, appending to ``nodes``/``attributes``/``leaves``.

  Args:
    node: object to flatten.
    node_impl: implementation record for ``node`` (as returned by
      ``get_node_impl``), or ``None``; ``None`` is only accepted when
      ``node`` is a ``Variable``.
    ref_index: mapping from visited objects to their index. Every object
      whose ``node_impl`` is not a ``PytreeNodeImpl`` is looked up here first
      and, when absent, stored under the next free index
      (``len(ref_index)``).
    ref_outer_index: optional mapping consulted for the outer index of
      Variables and graph nodes; ``None`` or an empty mapping disables the
      lookup.
    nodes: output list receiving one ``NodeRef``, ``VariableDef`` or
      ``NodeDef`` per visited object.
    attributes: output list receiving one ``(key, kind)`` pair per value
      produced by ``node_impl.flatten``.
    leaves: output list receiving Variable raw values and array values.

  Raises:
    RuntimeError: if ``node_impl`` is ``None`` and ``node`` is not a
      ``Variable``.
  """
  is_pytree_node_ = type(node_impl) is PytreeNodeImpl

  # An object that was already indexed is emitted as a back-reference only.
  if not is_pytree_node_ and node in ref_index:
    nodes.append(NodeRef(ref_index[node]))
    return

  is_graph_node_ = type(node_impl) is GraphNodeImpl

  # Pytree nodes are never indexed; anything else gets the next free index.
  if is_pytree_node_:
    index = None
  else:
    index = len(ref_index)
    ref_index[node] = index

  if isinstance(node, Variable):
    assert index is not None
    leaf = node.raw_value
    leaves.append(leaf)
    variabledef = VariableDef(
      type(node),
      index,
      ref_outer_index.get(node, None) if ref_outer_index else None,
      HashableMapping(node._var_metadata),
    )
    nodes.append(variabledef)
    return

  if node_impl is None:
    raise RuntimeError(f'Unsupported type: {type(node)} , this is a bug.')

  values, metadata = node_impl.flatten(node)
  num_attributes = len(values)
  # Only graph nodes carry an outer index; same lookup form as for Variables.
  outer_index = (
    ref_outer_index.get(node, None)
    if is_graph_node_ and ref_outer_index
    else None
  )
  nodedef = NodeDef(
    node_impl.type,
    index,
    outer_index,
    num_attributes,
    metadata,
  )
  nodes.append(nodedef)

  for key, value in values:
    value_node_impl = get_node_impl(value)
    if value_node_impl is not None or isinstance(value, Variable):
      # Nested node or Variable: record the attribute kind, then recurse.
      attributes.append((key, NODE_ATTR))
      _graph_flatten_fast(
        value,
        value_node_impl,
        ref_index,
        ref_outer_index,
        nodes,
        attributes,
        leaves,
      )
    elif isinstance(value, (jax.Array, np.ndarray)):
      attributes.append((key, ARRAY_ATTR))
      leaves.append(value)
    else:
      # Anything else is stored inline, wrapped in ``Static``.
      attributes.append((key, Static(value)))
783
+
685
784
686
785
@dataclasses .dataclass (slots = True )
687
786
class FingerprintContext :
@@ -1285,15 +1384,15 @@ class GraphContext(threading.local):
1285
1384
)
1286
1385
ref_index_stack : list [SplitContext ] = dataclasses .field (default_factory = list )
1287
1386
index_ref_stack : list [MergeContext ] = dataclasses .field (default_factory = list )
1288
- tmp_static_cache : RefMap | None = None
1387
+ tmp_static_cache : PyRefMap | None = None
1289
1388
caching : bool = False
1290
1389
1291
1390
1292
1391
GRAPH_CONTEXT = GraphContext ()
1293
1392
1294
1393
1295
1394
@contextlib .contextmanager
1296
- def static_cache (static_cache : RefMap ):
1395
+ def static_cache (static_cache : PyRefMap ):
1297
1396
if GRAPH_CONTEXT .caching :
1298
1397
yield
1299
1398
return
@@ -1356,7 +1455,7 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
1356
1455
Returns:
1357
1456
A partial function expecting the remaining arguments to the original function.
1358
1457
"""
1359
- cache : RefMap = RefMap ()
1458
+ cache = PyRefMap ()
1360
1459
original_ref_index : RefMap = RefMap ()
1361
1460
index_ref : IndexMap = IndexMap ()
1362
1461
cached_ref_index : RefMap = RefMap ()
@@ -1378,7 +1477,7 @@ def create_static_cache(x):
1378
1477
new_ref_index = cached_new_ref_index ,
1379
1478
)
1380
1479
cached_ref_index .update (cached_new_ref_index )
1381
- cache [node_cache ] = StaticCache .create (
1480
+ cache [node_cache ] = StaticCache .create ( # type: ignore
1382
1481
graphdef , paths , variables , cached_new_ref_index
1383
1482
)
1384
1483
return node_cache
@@ -1563,6 +1662,20 @@ def flatten(
1563
1662
else :
1564
1663
return graphdef , leaves
1565
1664
1665
+ def flatten_fast (self , node : A ) -> tuple [GraphDef [A ], list [tp .Any ]]:
1666
+ ctx = (
1667
+ current_update_context (self .ctxtag ) if self .ctxtag is not None else None
1668
+ )
1669
+ ref_outer_index = (
1670
+ ctx .inner_ref_outer_index if ctx and ctx .inner_ref_outer_index else None
1671
+ )
1672
+ graphdef , leaves = _flatten_fast (
1673
+ node ,
1674
+ ref_index = self .ref_index ,
1675
+ ref_outer_index = ref_outer_index ,
1676
+ )
1677
+ return graphdef , leaves
1678
+
1566
1679
1567
1680
@contextlib .contextmanager
1568
1681
def split_context (ctxtag : tp .Hashable | None = None ):
@@ -2620,9 +2733,9 @@ def _unflatten_pytree(
2620
2733
2621
2734
2622
2735
PYTREE_NODE_IMPL = PytreeNodeImpl (
2623
- type = GenericPytree ,
2624
- flatten = _flatten_pytree ,
2625
- unflatten = _unflatten_pytree , # type: ignore
2736
+ GenericPytree ,
2737
+ _flatten_pytree ,
2738
+ _unflatten_pytree , # type: ignore
2626
2739
)
2627
2740
2628
2741
# common pytrees
0 commit comments