Skip to content

Commit f309de1

Browse files
committed
[nnx] mutable array p1
1 parent dbce607 commit f309de1

24 files changed

+1136
-155
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# %%
16+
import os
17+
18+
os.environ['FLAX_MUTABLE_ARRAY'] = 'true'
19+
20+
import jax
21+
import jax.numpy as jnp
22+
import matplotlib.pyplot as plt
23+
import numpy as np
24+
25+
from flax import nnx
26+
from flax.nnx.variablelib import is_mutable_array
27+
28+
29+
def mutable_like(path, x):
30+
return (
31+
isinstance(x, nnx.Variable) and x.mutable
32+
) or nnx.variablelib.is_mutable_array(x)
33+
34+
35+
def freeze(x, only: nnx.filterlib.Filter = mutable_like):
36+
freeze_filter = nnx.filterlib.to_predicate(only)
37+
mutable_arrays: set[int] = set()
38+
39+
def check_mutable_array(path, x):
40+
m_array_id = id(x)
41+
if m_array_id in mutable_arrays:
42+
path_str = jax.tree_util.keystr(path)
43+
raise ValueError(
44+
f'Found duplicate MutableArray found at path {path_str}: {x}'
45+
)
46+
mutable_arrays.add(m_array_id)
47+
48+
def _freeze_fn(jax_path, x):
49+
path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
50+
if freeze_filter(path, x):
51+
if isinstance(x, nnx.Variable):
52+
check_mutable_array(jax_path, x.raw_value)
53+
return x.from_metadata(x[...], x.get_metadata().copy())
54+
elif nnx.variablelib.is_mutable_array(x):
55+
check_mutable_array(jax_path, x)
56+
return x[...]
57+
return x
58+
59+
return jax.tree.map_with_path(
60+
_freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
61+
)
62+
63+
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
64+
Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)
65+
66+
def dataset(batch_size):
67+
while True:
68+
idx = np.random.choice(len(X), size=batch_size)
69+
yield X[idx], Y[idx]
70+
71+
72+
class Linear(nnx.Module):
73+
__data__ = ('w', 'b')
74+
75+
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
76+
self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
77+
self.b = nnx.Param(jnp.zeros((dout,)))
78+
79+
def __call__(self, x):
80+
return x @ self.w[...] + self.b[None]
81+
82+
83+
class Count(nnx.Variable[nnx.A]):
84+
pass
85+
86+
87+
class MLP(nnx.Module):
88+
__data__ = ('count', 'linear1', 'linear2')
89+
90+
def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
91+
self.count = Count(jnp.array(0))
92+
self.linear1 = Linear(din, dhidden, rngs=rngs)
93+
self.linear2 = Linear(dhidden, dout, rngs=rngs)
94+
95+
def __call__(self, x):
96+
self.count[...] += 1
97+
return self.linear2(jax.nn.gelu(self.linear1(x)) * 0.5)
98+
99+
100+
model = MLP(din=1, dhidden=64, dout=1, rngs=nnx.Rngs(0))
101+
102+
103+
@jax.jit
104+
def train_step(model, x, y):
105+
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
106+
107+
def loss_fn(params):
108+
model = nnx.merge(graphdef, params, counts)
109+
return jnp.mean((y - model(x)) ** 2)
110+
111+
grads = jax.grad(loss_fn)(freeze(params))
112+
113+
def sgd(w, g):
114+
w[...] -= 0.1 * g[...]
115+
116+
jax.tree.map(sgd, params, grads)
117+
118+
119+
@jax.jit
120+
def test_step(model: MLP, x, y):
121+
return {'loss': jnp.mean((y - model(x)) ** 2)}
122+
123+
124+
total_steps = 10_000
125+
for step, (x, y) in enumerate(dataset(32)):
126+
train_step(model, x, y)
127+
128+
if step % 1000 == 0:
129+
logs = test_step(model, X, Y)
130+
print(f"step: {step}, loss: {logs['loss']}")
131+
132+
if step >= total_steps - 1:
133+
break
134+
135+
print('times called:', model.count.value)
136+
137+
y_pred = model(X)
138+
139+
plt.scatter(X, Y, color='blue')
140+
plt.plot(X, y_pred, color='black')
141+
plt.show()

0 commit comments

Comments
 (0)