Skip to content

Commit c1b9fc3

Browse files
committed
add fast_flatten
1 parent 1dddcfe commit c1b9fc3

File tree

5 files changed

+403
-116
lines changed

5 files changed

+403
-116
lines changed

flax/nnx/graph.py

Lines changed: 179 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@
5656
GraphState = State[Key, StateLeaf]
5757
GraphFlatState = FlatState[StateLeaf]
5858

59+
def maybe_use_flaxlib(name: str, register_static: bool = False):
60+
if config.flax_use_flaxlib:
61+
import flaxlib
62+
63+
flaxlib_version: type = getattr(flaxlib, name)
64+
globals()[name] = flaxlib_version
65+
66+
if register_static:
67+
jax.tree_util.register_static(flaxlib_version)
5968

6069
def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
6170
return isinstance(x, VariableState)
@@ -71,10 +80,7 @@ def from_refmap(refmap: RefMap) -> IndexMap:
7180
indexmap.update((index, value) for value, index in refmap.items())
7281
return indexmap
7382

74-
if config.flax_use_flaxlib:
75-
import flaxlib
76-
77-
globals()['IndexMap'] = flaxlib.IndexMap
83+
maybe_use_flaxlib('IndexMap')
7884

7985

8086
# RefMap = dict
@@ -120,11 +126,12 @@ def __iter__(self) -> tp.Iterator[tp.Any]:
120126
def items(self) -> tp.ItemsView[tp.Any, int]:
121127
return self._mapping.values() # type: ignore
122128

129+
# TODO(cgarciae): PyRefMap is currently being used for
130+
# cached_partial, but it should be removed when cached_partial
131+
# is no longer needed.
132+
PyRefMap = RefMap
123133

124-
if config.flax_use_flaxlib:
125-
import flaxlib
126-
127-
globals()['RefMap'] = flaxlib.RefMap
134+
maybe_use_flaxlib('RefMap')
128135

129136

130137
@dataclasses.dataclass(frozen=True, slots=True)
@@ -136,28 +143,33 @@ def node_dict(self, node: Node) -> dict[Key, Leaf]:
136143
nodes, _ = self.flatten(node)
137144
return dict(nodes)
138145

146+
maybe_use_flaxlib('NodeImplBase')
139147

140148
@dataclasses.dataclass(frozen=True, slots=True)
141-
class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
142-
set_key: tp.Callable[[Node, Key, Leaf], None]
143-
pop_key: tp.Callable[[Node, Key], Leaf]
144-
create_empty: tp.Callable[[AuxData], Node]
145-
clear: tp.Callable[[Node], None]
146-
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None]
149+
class GraphNodeImpl(NodeImplBase):
150+
set_key: tp.Callable[[tp.Any, Key, tp.Any], None]
151+
pop_key: tp.Callable[[tp.Any, Key], tp.Any]
152+
create_empty: tp.Callable[[tp.Any], tp.Any]
153+
clear: tp.Callable[[tp.Any], None]
154+
init: tp.Callable[[tp.Any, tp.Iterable[tuple[Key, tp.Any]]], None]
155+
156+
157+
maybe_use_flaxlib('GraphNodeImpl')
147158

148159

149160
@dataclasses.dataclass(frozen=True, slots=True)
150-
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
151-
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]
161+
class PytreeNodeImpl(NodeImplBase):
162+
unflatten: tp.Callable[[tp.Sequence[tuple[Key, tp.Any]], tp.Any], tp.Any]
163+
164+
165+
maybe_use_flaxlib('PytreeNodeImpl')
152166

153167

154-
NodeImpl = tp.Union[
155-
GraphNodeImpl[Node, Leaf, AuxData], PytreeNodeImpl[Node, Leaf, AuxData]
156-
]
168+
NodeImpl = tp.Union[GraphNodeImpl, PytreeNodeImpl]
157169

158170

159-
GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {}
160-
PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {}
171+
GRAPH_REGISTRY: dict[type, NodeImpl] = {}
172+
PYTREE_REGISTRY: dict[type, PytreeNodeImpl] = {}
161173

162174

163175
def register_graph_node_type(
@@ -173,13 +185,13 @@ def register_graph_node_type(
173185
raise ValueError(f'Node type {type} is already registered.')
174186

175187
GRAPH_REGISTRY[type] = GraphNodeImpl(
176-
type=type,
177-
flatten=flatten,
178-
set_key=set_key,
179-
pop_key=pop_key,
180-
create_empty=create_empty,
181-
clear=clear,
182-
init=init,
188+
type,
189+
flatten,
190+
set_key,
191+
pop_key,
192+
create_empty,
193+
clear,
194+
init,
183195
)
184196

185197

@@ -191,9 +203,7 @@ def register_pytree_node_type(
191203
if type in PYTREE_REGISTRY:
192204
raise ValueError(f'Node type {type} is already registered.')
193205

194-
PYTREE_REGISTRY[type] = PytreeNodeImpl(
195-
type=type, flatten=flatten, unflatten=unflatten
196-
)
206+
PYTREE_REGISTRY[type] = PytreeNodeImpl(type, flatten, unflatten)
197207

198208

199209
def is_node(x: tp.Any) -> bool:
@@ -210,16 +220,13 @@ def is_node_type(x: type[tp.Any]) -> bool:
210220
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
211221

212222

213-
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
214-
if isinstance(x, Variable):
215-
return None
216-
223+
def get_node_impl(x) -> NodeImpl | None:
217224
node_type = type(x)
218225

219-
if node_type in GRAPH_REGISTRY:
220-
return GRAPH_REGISTRY[node_type]
221-
elif node_type in PYTREE_REGISTRY:
222-
return PYTREE_REGISTRY[node_type]
226+
if node_impl := GRAPH_REGISTRY.get(node_type):
227+
return node_impl
228+
elif node_impl := PYTREE_REGISTRY.get(node_type):
229+
return node_impl
223230
elif node_type in JAX_PYTREE_REGISTRY or issubclass(node_type, tuple):
224231
return PYTREE_NODE_IMPL # type: ignore
225232
else:
@@ -285,11 +292,7 @@ def __treescope_repr__(self, path, subtree_renderer):
285292
subtree_renderer=subtree_renderer,
286293
)
287294

288-
if config.flax_use_flaxlib:
289-
import flaxlib
290-
291-
jax.tree_util.register_static(flaxlib.NodeRef)
292-
globals()['NodeRef'] = flaxlib.NodeRef
295+
maybe_use_flaxlib('NodeRef', register_static=True)
293296

294297
@jax.tree_util.register_static
295298
@dataclasses.dataclass(frozen=True, repr=False)
@@ -331,11 +334,7 @@ def __treescope_repr__(self, path, subtree_renderer):
331334
subtree_renderer=subtree_renderer,
332335
)
333336

334-
if config.flax_use_flaxlib:
335-
import flaxlib
336-
337-
jax.tree_util.register_static(flaxlib.VariableDef)
338-
globals()['VariableDef'] = flaxlib.VariableDef
337+
maybe_use_flaxlib('VariableDef', register_static=True)
339338

340339
@jax.tree_util.register_static
341340
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
@@ -391,12 +390,7 @@ def __treescope_repr__(self, path, subtree_renderer):
391390
subtree_renderer=subtree_renderer,
392391
)
393392

394-
395-
if config.flax_use_flaxlib:
396-
import flaxlib
397-
398-
jax.tree_util.register_static(flaxlib.NodeDef)
399-
globals()['NodeDef'] = flaxlib.NodeDef
393+
maybe_use_flaxlib('NodeDef', register_static=True)
400394

401395

402396
@jax.tree_util.register_static
@@ -627,10 +621,10 @@ def _graph_flatten(
627621
assert paths is not None
628622
paths.append(tuple(path))
629623
variabledef = VariableDef(
630-
type=type(node),
631-
index=index,
632-
outer_index=ref_outer_index.get(node, None) if ref_outer_index else None,
633-
metadata=HashableMapping(node._var_metadata),
624+
type(node),
625+
index,
626+
ref_outer_index.get(node, None) if ref_outer_index else None,
627+
HashableMapping(node._var_metadata),
634628
)
635629
nodes.append(variabledef)
636630
return
@@ -682,6 +676,111 @@ def _graph_flatten(
682676

683677
return
684678

679+
def _flatten_fast(
680+
node: Node,
681+
/,
682+
*,
683+
ref_index: RefMap,
684+
ref_outer_index: RefMap | None,
685+
) -> tuple[GraphDef[Node], list[tp.Any]]:
686+
leaves: list[jax.Array | np.ndarray] = []
687+
nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]] = []
688+
attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]] = []
689+
node_impl = get_node_impl(node)
690+
if node_impl is None and not isinstance(node, Variable):
691+
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
692+
_graph_flatten_fast(
693+
node,
694+
node_impl,
695+
ref_index,
696+
ref_outer_index,
697+
nodes,
698+
attributes,
699+
leaves,
700+
)
701+
graphdef = GraphDef(
702+
nodes=nodes, attributes=attributes, num_leaves=len(leaves)
703+
)
704+
705+
return graphdef, leaves
706+
707+
708+
def _graph_flatten_fast(
709+
node: Node,
710+
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
711+
ref_index: RefMap,
712+
ref_outer_index: RefMap | None,
713+
nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]],
714+
attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]],
715+
leaves: list[jax.Array | np.ndarray],
716+
) -> None:
717+
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
718+
719+
if not is_pytree_node_ and node in ref_index:
720+
nodes.append(NodeRef(index := ref_index[node]))
721+
return
722+
723+
is_graph_node_ = type(node_impl) is GraphNodeImpl
724+
is_variable = isinstance(node, Variable)
725+
726+
# only cache graph nodes
727+
if is_pytree_node_:
728+
index = None
729+
else:
730+
index = len(ref_index)
731+
ref_index[node] = index
732+
733+
if is_variable:
734+
assert isinstance(node, Variable)
735+
assert index is not None
736+
leaf = node.raw_value
737+
leaves.append(leaf)
738+
variabledef = VariableDef(
739+
type(node),
740+
index,
741+
ref_outer_index.get(node, None) if ref_outer_index else None,
742+
HashableMapping(node._var_metadata),
743+
)
744+
nodes.append(variabledef)
745+
return
746+
747+
if node_impl is None:
748+
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
749+
750+
values, metadata = node_impl.flatten(node)
751+
num_attributes = len(values)
752+
nodedef = NodeDef(
753+
node_impl.type,
754+
index,
755+
ref_outer_index[node]
756+
if is_graph_node_ and ref_outer_index and node in ref_outer_index
757+
else None,
758+
num_attributes,
759+
metadata,
760+
)
761+
nodes.append(nodedef)
762+
763+
for key, value in values:
764+
value_node_impl = get_node_impl(value)
765+
if value_node_impl is not None or isinstance(value, Variable):
766+
attributes.append((key, NODE_ATTR))
767+
_graph_flatten_fast(
768+
value,
769+
value_node_impl,
770+
ref_index,
771+
ref_outer_index,
772+
nodes,
773+
attributes,
774+
leaves,
775+
)
776+
elif isinstance(value, (jax.Array, np.ndarray)):
777+
attributes.append((key, ARRAY_ATTR))
778+
leaves.append(value)
779+
else:
780+
attributes.append((key, Static(value)))
781+
782+
return
783+
685784

686785
@dataclasses.dataclass(slots=True)
687786
class FingerprintContext:
@@ -1285,15 +1384,15 @@ class GraphContext(threading.local):
12851384
)
12861385
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
12871386
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
1288-
tmp_static_cache: RefMap | None = None
1387+
tmp_static_cache: PyRefMap | None = None
12891388
caching: bool = False
12901389

12911390

12921391
GRAPH_CONTEXT = GraphContext()
12931392

12941393

12951394
@contextlib.contextmanager
1296-
def static_cache(static_cache: RefMap):
1395+
def static_cache(static_cache: PyRefMap):
12971396
if GRAPH_CONTEXT.caching:
12981397
yield
12991398
return
@@ -1356,7 +1455,7 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
13561455
Returns:
13571456
A partial function expecting the remaining arguments to the original function.
13581457
"""
1359-
cache: RefMap = RefMap()
1458+
cache = PyRefMap()
13601459
original_ref_index: RefMap = RefMap()
13611460
index_ref: IndexMap = IndexMap()
13621461
cached_ref_index: RefMap = RefMap()
@@ -1378,7 +1477,7 @@ def create_static_cache(x):
13781477
new_ref_index=cached_new_ref_index,
13791478
)
13801479
cached_ref_index.update(cached_new_ref_index)
1381-
cache[node_cache] = StaticCache.create(
1480+
cache[node_cache] = StaticCache.create( # type: ignore
13821481
graphdef, paths, variables, cached_new_ref_index
13831482
)
13841483
return node_cache
@@ -1563,6 +1662,20 @@ def flatten(
15631662
else:
15641663
return graphdef, leaves
15651664

1665+
def flatten_fast(self, node: A) -> tuple[GraphDef[A], list[tp.Any]]:
1666+
ctx = (
1667+
current_update_context(self.ctxtag) if self.ctxtag is not None else None
1668+
)
1669+
ref_outer_index = (
1670+
ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None
1671+
)
1672+
graphdef, leaves = _flatten_fast(
1673+
node,
1674+
ref_index=self.ref_index,
1675+
ref_outer_index=ref_outer_index,
1676+
)
1677+
return graphdef, leaves
1678+
15661679

15671680
@contextlib.contextmanager
15681681
def split_context(ctxtag: tp.Hashable | None = None):
@@ -2620,9 +2733,9 @@ def _unflatten_pytree(
26202733

26212734

26222735
PYTREE_NODE_IMPL = PytreeNodeImpl(
2623-
type=GenericPytree,
2624-
flatten=_flatten_pytree,
2625-
unflatten=_unflatten_pytree, # type: ignore
2736+
GenericPytree,
2737+
_flatten_pytree,
2738+
_unflatten_pytree, # type: ignore
26262739
)
26272740

26282741
# common pytrees

flax/nnx/transforms/compilation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,13 @@ def __hash__(self):
9595

9696

9797
def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x):
98-
if isinstance(prefix, StateSharding):
98+
if type(prefix) is StateSharding:
9999
graphdef, *states = ctx.flatten(x, *prefix.filters)
100100
return extract.NodeStates.from_split(graphdef, *states, metadata=prefix)
101-
return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False))
101+
elif graph.GRAPH_CONTEXT.caching:
102+
return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False))
103+
else:
104+
return extract.NodeStates.from_split(*ctx.flatten_fast(x))
102105

103106

104107
def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any:

0 commit comments

Comments
 (0)