Skip to content

Commit a7157cf

Browse files
author
Flax Authors
committed
Merge pull request #4715 from google:nnx-mutable-array-p1
PiperOrigin-RevId: 756050354
2 parents 12a29ec + 6ab5f83 commit a7157cf

36 files changed

+1429
-2187
lines changed

docs_nnx/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139

140140
# -- Options for myst ----------------------------------------------
141141
# uncomment line below to avoid running notebooks during development
142-
nb_execution_mode = 'off'
142+
# nb_execution_mode = 'off'
143143
# Notebook cell execution timeout; defaults to 30.
144144
nb_execution_timeout = 100
145145
# List of patterns, relative to source directory, that match notebook
@@ -151,6 +151,7 @@
151151
'flax/nnx', # exclude nnx
152152
'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update
153153
'guides/gemma.ipynb',
154+
'guides/bridge_guide.ipynb', # TODO(cgarciae): broken, bridge doesn't support Linen sow yet
154155
]
155156
# raise exceptions on execution so CI can catch errors
156157
nb_execution_allow_errors = False

docs_nnx/guides/bridge_guide.ipynb

Lines changed: 126 additions & 141 deletions
Large diffs are not rendered by default.

docs_nnx/guides/bridge_guide.md

Lines changed: 48 additions & 142 deletions
Large diffs are not rendered by default.

docs_nnx/guides/filters_guide.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@
284284
" predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
285285
" flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
286286
"\n",
287-
" for path, value in state.flat_state():\n",
287+
" for path, value in state:\n",
288288
" for i, predicate in enumerate(predicates):\n",
289289
" if predicate(path, value):\n",
290290
" flat_states[i][path] = value\n",
@@ -415,7 +415,7 @@
415415
"name": "python",
416416
"nbconvert_exporter": "python",
417417
"pygments_lexer": "ipython3",
418-
"version": "3.10.13"
418+
"version": "3.11.9"
419419
}
420420
},
421421
"nbformat": 4,

docs_nnx/guides/filters_guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def split(node, *filters):
149149
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
150150
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
151151
152-
for path, value in state.flat_state():
152+
for path, value in state:
153153
for i, predicate in enumerate(predicates):
154154
if predicate(path, value):
155155
flat_states[i][path] = value

docs_nnx/guides/performance.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
"\n",
148148
"for _ in range(10):\n",
149149
" x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n",
150-
" state, loss = jax_train_step(graphdef, state, x, y)\n",
150+
" loss, state = jax_train_step(graphdef, state, x, y)\n",
151151
"\n",
152152
"# update objects after training\n",
153153
"nnx.update((model, optimizer, metrics), state)"

docs_nnx/guides/performance.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def jax_train_step(graphdef, state, x, y):
119119
120120
for _ in range(10):
121121
x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
122-
state, loss = jax_train_step(graphdef, state, x, y)
122+
loss, state = jax_train_step(graphdef, state, x, y)
123123
124124
# update objects after training
125125
nnx.update((model, optimizer, metrics), state)

0 commit comments

Comments
 (0)