Skip to content

Commit 2a5df33

Browse files
committed
[nnx] refactor GraphDef
1 parent aa9d572 commit 2a5df33

File tree

12 files changed

+388
-323
lines changed

12 files changed

+388
-323
lines changed

benchmarks/nnx_graph_overhead.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,9 @@ def main(argv):
9797
def step_nnx(model: MLP, optimizer: nnx.Optimizer):
9898
pass
9999

100-
cached_step_nnx = nnx.cached_partial(step_nnx, model, optimizer)
101-
102100
t0 = time()
103101
for _ in range(total_steps):
104-
cached_step_nnx()
102+
step_nnx(model, optimizer)
105103

106104
total_time = time() - t0
107105
time_per_step = total_time / total_steps

flax/nnx/bridge/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _graph_node_flatten(self):
255255
PriorityStr(self.attr_priorities.get(k, AttrPriority.DEFAULT), k)
256256
for k in nodes.keys()
257257
)
258-
sorted_nodes = ((k, nodes[k]) for k in sorted(keys))
258+
sorted_nodes = list((k, nodes[k]) for k in sorted(keys))
259259
return sorted_nodes, type(self)
260260

261261
def set_attr_priority(self, name: str, value: AttrPriority):

flax/nnx/bridge/variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
154154
elif isinstance(v, variablelib.VariableState):
155155
col_name = variablelib.variable_name_from_type(v.type)
156156
v = to_linen_var(v)
157-
elif isinstance(v, graph.NodeDef) or isinstance(v, graph.NodeRef):
157+
elif isinstance(v, graph.GraphDef):
158158
col_name = 'nnx' # an nnx.GraphDef for some ToLinen submodule
159159
else:
160160
raise ValueError(f'Cannot infer collection name from value: {v}')

flax/nnx/bridge/wrappers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
@dataclasses.dataclass
3939
class Functional(tp.Generic[M]):
4040
module_type: tp.Type[M]
41-
graphdef: tp.Optional[graph.NodeDef[M]]
41+
graphdef: tp.Optional[graph.GraphDef[M]]
4242
args: tuple[tp.Any, ...]
4343
kwargs: dict[str, tp.Any]
4444

@@ -48,7 +48,6 @@ def init(self, *, rngs: tp.Optional[Rngs] = None) -> State:
4848
kwargs['rngs'] = rngs
4949
module = self.module_type(*self.args, **self.kwargs, **kwargs)
5050
graphdef, state = nnx.split(module)
51-
assert type(graphdef) is graph.NodeDef
5251
self.graphdef = graphdef
5352
return state # type: ignore
5453

@@ -217,7 +216,7 @@ class ToLinen(linen.Module):
217216
>>> variables.keys()
218217
dict_keys(['nnx', 'params'])
219218
>>> type(variables['nnx']['graphdef'])
220-
<class 'flax.nnx.graph.NodeDef'>
219+
<class 'flax.nnx.graph.GraphDef'>
221220
222221
Args:
223222
nnx_class: The NNX Module class (not instance!).

0 commit comments

Comments
 (0)