Skip to content

Commit 64b42f8

Browse files
committed
add fast_flatten
1 parent 8bf63ad commit 64b42f8

File tree

5 files changed

+403
-116
lines changed

5 files changed

+403
-116
lines changed

flax/nnx/graph.py

Lines changed: 179 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@
5656
GraphState = State[Key, StateLeaf]
5757
GraphFlatState = FlatState[StateLeaf]
5858

59+
def maybe_use_flaxlib(name: str, register_static: bool = False):
60+
if config.flax_use_flaxlib:
61+
import flaxlib
62+
63+
flaxlib_version: type = getattr(flaxlib, name)
64+
globals()[name] = flaxlib_version
65+
66+
if register_static:
67+
jax.tree_util.register_static(flaxlib_version)
5968

6069
def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
6170
return isinstance(x, VariableState)
@@ -71,10 +80,7 @@ def from_refmap(refmap: RefMap) -> IndexMap:
7180
indexmap.update((index, value) for value, index in refmap.items())
7281
return indexmap
7382

74-
if config.flax_use_flaxlib:
75-
import flaxlib
76-
77-
globals()['IndexMap'] = flaxlib.IndexMap
83+
maybe_use_flaxlib('IndexMap')
7884

7985

8086
# RefMap = dict
@@ -120,11 +126,12 @@ def __iter__(self) -> tp.Iterator[tp.Any]:
120126
def items(self) -> tp.ItemsView[tp.Any, int]:
121127
return self._mapping.values() # type: ignore
122128

129+
# TODO(cgarciae): PyRefMap is currently being used for
130+
# cached_partial, but it should be removed when cached_partial
131+
# is no longer needed.
132+
PyRefMap = RefMap
123133

124-
if config.flax_use_flaxlib:
125-
import flaxlib
126-
127-
globals()['RefMap'] = flaxlib.RefMap
134+
maybe_use_flaxlib('RefMap')
128135

129136

130137
@dataclasses.dataclass(frozen=True, slots=True)
@@ -136,28 +143,33 @@ def node_dict(self, node: Node) -> dict[Key, Leaf]:
136143
nodes, _ = self.flatten(node)
137144
return dict(nodes)
138145

146+
maybe_use_flaxlib('NodeImplBase')
139147

140148
@dataclasses.dataclass(frozen=True, slots=True)
141-
class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
142-
set_key: tp.Callable[[Node, Key, Leaf], None]
143-
pop_key: tp.Callable[[Node, Key], Leaf]
144-
create_empty: tp.Callable[[AuxData], Node]
145-
clear: tp.Callable[[Node], None]
146-
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None]
149+
class GraphNodeImpl(NodeImplBase):
150+
set_key: tp.Callable[[tp.Any, Key, tp.Any], None]
151+
pop_key: tp.Callable[[tp.Any, Key], tp.Any]
152+
create_empty: tp.Callable[[tp.Any], tp.Any]
153+
clear: tp.Callable[[tp.Any], None]
154+
init: tp.Callable[[tp.Any, tp.Iterable[tuple[Key, tp.Any]]], None]
155+
156+
157+
maybe_use_flaxlib('GraphNodeImpl')
147158

148159

149160
@dataclasses.dataclass(frozen=True, slots=True)
150-
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
151-
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]
161+
class PytreeNodeImpl(NodeImplBase):
162+
unflatten: tp.Callable[[tp.Sequence[tuple[Key, tp.Any]], tp.Any], tp.Any]
163+
164+
165+
maybe_use_flaxlib('PytreeNodeImpl')
152166

153167

154-
NodeImpl = tp.Union[
155-
GraphNodeImpl[Node, Leaf, AuxData], PytreeNodeImpl[Node, Leaf, AuxData]
156-
]
168+
NodeImpl = tp.Union[GraphNodeImpl, PytreeNodeImpl]
157169

158170

159-
GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {}
160-
PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {}
171+
GRAPH_REGISTRY: dict[type, NodeImpl] = {}
172+
PYTREE_REGISTRY: dict[type, PytreeNodeImpl] = {}
161173

162174

163175
def register_graph_node_type(
@@ -173,13 +185,13 @@ def register_graph_node_type(
173185
raise ValueError(f'Node type {type} is already registered.')
174186

175187
GRAPH_REGISTRY[type] = GraphNodeImpl(
176-
type=type,
177-
flatten=flatten,
178-
set_key=set_key,
179-
pop_key=pop_key,
180-
create_empty=create_empty,
181-
clear=clear,
182-
init=init,
188+
type,
189+
flatten,
190+
set_key,
191+
pop_key,
192+
create_empty,
193+
clear,
194+
init,
183195
)
184196

185197

@@ -191,9 +203,7 @@ def register_pytree_node_type(
191203
if type in PYTREE_REGISTRY:
192204
raise ValueError(f'Node type {type} is already registered.')
193205

194-
PYTREE_REGISTRY[type] = PytreeNodeImpl(
195-
type=type, flatten=flatten, unflatten=unflatten
196-
)
206+
PYTREE_REGISTRY[type] = PytreeNodeImpl(type, flatten, unflatten)
197207

198208

199209
def is_node(x: tp.Any) -> bool:
@@ -210,16 +220,13 @@ def is_node_type(x: type[tp.Any]) -> bool:
210220
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
211221

212222

213-
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
214-
if isinstance(x, Variable):
215-
return None
216-
223+
def get_node_impl(x) -> NodeImpl | None:
217224
node_type = type(x)
218225

219-
if node_type in GRAPH_REGISTRY:
220-
return GRAPH_REGISTRY[node_type]
221-
elif node_type in PYTREE_REGISTRY:
222-
return PYTREE_REGISTRY[node_type]
226+
if node_impl := GRAPH_REGISTRY.get(node_type):
227+
return node_impl
228+
elif node_impl := PYTREE_REGISTRY.get(node_type):
229+
return node_impl
223230
elif node_type in JAX_PYTREE_REGISTRY or issubclass(node_type, tuple):
224231
return PYTREE_NODE_IMPL # type: ignore
225232
else:
@@ -285,11 +292,7 @@ def __treescope_repr__(self, path, subtree_renderer):
285292
subtree_renderer=subtree_renderer,
286293
)
287294

288-
if config.flax_use_flaxlib:
289-
import flaxlib
290-
291-
jax.tree_util.register_static(flaxlib.NodeRef)
292-
globals()['NodeRef'] = flaxlib.NodeRef
295+
maybe_use_flaxlib('NodeRef', register_static=True)
293296

294297
@jax.tree_util.register_static
295298
@dataclasses.dataclass(frozen=True, repr=False)
@@ -331,11 +334,7 @@ def __treescope_repr__(self, path, subtree_renderer):
331334
subtree_renderer=subtree_renderer,
332335
)
333336

334-
if config.flax_use_flaxlib:
335-
import flaxlib
336-
337-
jax.tree_util.register_static(flaxlib.VariableDef)
338-
globals()['VariableDef'] = flaxlib.VariableDef
337+
maybe_use_flaxlib('VariableDef', register_static=True)
339338

340339
@jax.tree_util.register_static
341340
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
@@ -391,12 +390,7 @@ def __treescope_repr__(self, path, subtree_renderer):
391390
subtree_renderer=subtree_renderer,
392391
)
393392

394-
395-
if config.flax_use_flaxlib:
396-
import flaxlib
397-
398-
jax.tree_util.register_static(flaxlib.NodeDef)
399-
globals()['NodeDef'] = flaxlib.NodeDef
393+
maybe_use_flaxlib('NodeDef', register_static=True)
400394

401395

402396
@jax.tree_util.register_static
@@ -628,10 +622,10 @@ def _graph_flatten(
628622
assert paths is not None
629623
paths.append(tuple(path))
630624
variabledef = VariableDef(
631-
type=type(node),
632-
index=index,
633-
outer_index=ref_outer_index.get(node, None) if ref_outer_index else None,
634-
metadata=HashableMapping(node._var_metadata),
625+
type(node),
626+
index,
627+
ref_outer_index.get(node, None) if ref_outer_index else None,
628+
HashableMapping(node._var_metadata),
635629
)
636630
nodes.append(variabledef)
637631
return
@@ -683,6 +677,111 @@ def _graph_flatten(
683677

684678
return
685679

680+
def _flatten_fast(
681+
node: Node,
682+
/,
683+
*,
684+
ref_index: RefMap,
685+
ref_outer_index: RefMap | None,
686+
) -> tuple[GraphDef[Node], list[tp.Any]]:
687+
leaves: list[jax.Array | np.ndarray] = []
688+
nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]] = []
689+
attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]] = []
690+
node_impl = get_node_impl(node)
691+
if node_impl is None and not isinstance(node, Variable):
692+
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
693+
_graph_flatten_fast(
694+
node,
695+
node_impl,
696+
ref_index,
697+
ref_outer_index,
698+
nodes,
699+
attributes,
700+
leaves,
701+
)
702+
graphdef = GraphDef(
703+
nodes=nodes, attributes=attributes, num_leaves=len(leaves)
704+
)
705+
706+
return graphdef, leaves
707+
708+
709+
def _graph_flatten_fast(
710+
node: Node,
711+
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
712+
ref_index: RefMap,
713+
ref_outer_index: RefMap | None,
714+
nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]],
715+
attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]],
716+
leaves: list[jax.Array | np.ndarray],
717+
) -> None:
718+
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
719+
720+
if not is_pytree_node_ and node in ref_index:
721+
nodes.append(NodeRef(index := ref_index[node]))
722+
return
723+
724+
is_graph_node_ = type(node_impl) is GraphNodeImpl
725+
is_variable = isinstance(node, Variable)
726+
727+
# only cache graph nodes
728+
if is_pytree_node_:
729+
index = None
730+
else:
731+
index = len(ref_index)
732+
ref_index[node] = index
733+
734+
if is_variable:
735+
assert isinstance(node, Variable)
736+
assert index is not None
737+
leaf = node.raw_value
738+
leaves.append(leaf)
739+
variabledef = VariableDef(
740+
type(node),
741+
index,
742+
ref_outer_index.get(node, None) if ref_outer_index else None,
743+
HashableMapping(node._var_metadata),
744+
)
745+
nodes.append(variabledef)
746+
return
747+
748+
if node_impl is None:
749+
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
750+
751+
values, metadata = node_impl.flatten(node)
752+
num_attributes = len(values)
753+
nodedef = NodeDef(
754+
node_impl.type,
755+
index,
756+
ref_outer_index[node]
757+
if is_graph_node_ and ref_outer_index and node in ref_outer_index
758+
else None,
759+
num_attributes,
760+
metadata,
761+
)
762+
nodes.append(nodedef)
763+
764+
for key, value in values:
765+
value_node_impl = get_node_impl(value)
766+
if value_node_impl is not None or isinstance(value, Variable):
767+
attributes.append((key, NODE_ATTR))
768+
_graph_flatten_fast(
769+
value,
770+
value_node_impl,
771+
ref_index,
772+
ref_outer_index,
773+
nodes,
774+
attributes,
775+
leaves,
776+
)
777+
elif isinstance(value, (jax.Array, np.ndarray)):
778+
attributes.append((key, ARRAY_ATTR))
779+
leaves.append(value)
780+
else:
781+
attributes.append((key, Static(value)))
782+
783+
return
784+
686785

687786
@dataclasses.dataclass(slots=True)
688787
class FingerprintContext:
@@ -1286,15 +1385,15 @@ class GraphContext(threading.local):
12861385
)
12871386
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
12881387
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
1289-
tmp_static_cache: RefMap | None = None
1388+
tmp_static_cache: PyRefMap | None = None
12901389
caching: bool = False
12911390

12921391

12931392
GRAPH_CONTEXT = GraphContext()
12941393

12951394

12961395
@contextlib.contextmanager
1297-
def static_cache(static_cache: RefMap):
1396+
def static_cache(static_cache: PyRefMap):
12981397
if GRAPH_CONTEXT.caching:
12991398
yield
13001399
return
@@ -1357,7 +1456,7 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
13571456
Returns:
13581457
A partial function expecting the remaining arguments to the original function.
13591458
"""
1360-
cache: RefMap = RefMap()
1459+
cache = PyRefMap()
13611460
original_ref_index: RefMap = RefMap()
13621461
index_ref: IndexMap = IndexMap()
13631462
cached_ref_index: RefMap = RefMap()
@@ -1379,7 +1478,7 @@ def create_static_cache(x):
13791478
new_ref_index=cached_new_ref_index,
13801479
)
13811480
cached_ref_index.update(cached_new_ref_index)
1382-
cache[node_cache] = StaticCache.create(
1481+
cache[node_cache] = StaticCache.create( # type: ignore
13831482
graphdef, paths, variables, cached_new_ref_index
13841483
)
13851484
return node_cache
@@ -1564,6 +1663,20 @@ def flatten(
15641663
else:
15651664
return graphdef, leaves
15661665

1666+
def flatten_fast(self, node: A) -> tuple[GraphDef[A], list[tp.Any]]:
1667+
ctx = (
1668+
current_update_context(self.ctxtag) if self.ctxtag is not None else None
1669+
)
1670+
ref_outer_index = (
1671+
ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None
1672+
)
1673+
graphdef, leaves = _flatten_fast(
1674+
node,
1675+
ref_index=self.ref_index,
1676+
ref_outer_index=ref_outer_index,
1677+
)
1678+
return graphdef, leaves
1679+
15671680

15681681
@contextlib.contextmanager
15691682
def split_context(ctxtag: tp.Hashable | None = None):
@@ -2621,9 +2734,9 @@ def _unflatten_pytree(
26212734

26222735

26232736
PYTREE_NODE_IMPL = PytreeNodeImpl(
2624-
type=GenericPytree,
2625-
flatten=_flatten_pytree,
2626-
unflatten=_unflatten_pytree, # type: ignore
2737+
GenericPytree,
2738+
_flatten_pytree,
2739+
_unflatten_pytree, # type: ignore
26272740
)
26282741

26292742
# common pytrees

flax/nnx/transforms/compilation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,13 @@ def __hash__(self):
9595

9696

9797
def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x):
98-
if isinstance(prefix, StateSharding):
98+
if type(prefix) is StateSharding:
9999
graphdef, *states = ctx.flatten(x, *prefix.filters)
100100
return extract.NodeStates.from_split(graphdef, *states, metadata=prefix)
101-
return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False))
101+
elif graph.GRAPH_CONTEXT.caching:
102+
return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False))
103+
else:
104+
return extract.NodeStates.from_split(*ctx.flatten_fast(x))
102105

103106

104107
def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any:

0 commit comments

Comments
 (0)