|
| 1 | +# Copyright 2024 The Flax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# %% |
| 16 | +import os |
| 17 | + |
| 18 | +os.environ['FLAX_MUTABLE_ARRAY'] = 'true' |
| 19 | + |
| 20 | +import jax |
| 21 | +import jax.numpy as jnp |
| 22 | +import matplotlib.pyplot as plt |
| 23 | +import numpy as np |
| 24 | + |
| 25 | +from flax import nnx |
| 26 | +from flax.nnx.variablelib import is_mutable_array |
| 27 | + |
| 28 | + |
| 29 | +def mutable_like(path, x): |
| 30 | + return ( |
| 31 | + isinstance(x, nnx.Variable) and x.mutable |
| 32 | + ) or nnx.variablelib.is_mutable_array(x) |
| 33 | + |
| 34 | + |
| 35 | +def freeze(x, only: nnx.filterlib.Filter = mutable_like): |
| 36 | + freeze_filter = nnx.filterlib.to_predicate(only) |
| 37 | + mutable_arrays: set[int] = set() |
| 38 | + |
| 39 | + def check_mutable_array(path, x): |
| 40 | + m_array_id = id(x) |
| 41 | + if m_array_id in mutable_arrays: |
| 42 | + path_str = jax.tree_util.keystr(path) |
| 43 | + raise ValueError( |
| 44 | + f'Found duplicate MutableArray found at path {path_str}: {x}' |
| 45 | + ) |
| 46 | + mutable_arrays.add(m_array_id) |
| 47 | + |
| 48 | + def _freeze_fn(jax_path, x): |
| 49 | + path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path) |
| 50 | + if freeze_filter(path, x): |
| 51 | + if isinstance(x, nnx.Variable): |
| 52 | + check_mutable_array(jax_path, x.raw_value) |
| 53 | + return x.from_metadata(x[...], x.get_metadata().copy()) |
| 54 | + elif nnx.variablelib.is_mutable_array(x): |
| 55 | + check_mutable_array(jax_path, x) |
| 56 | + return x[...] |
| 57 | + return x |
| 58 | + |
| 59 | + return jax.tree.map_with_path( |
| 60 | + _freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable) |
| 61 | + ) |
| 62 | + |
| 63 | +X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None] |
| 64 | +Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape) |
| 65 | + |
| 66 | +def dataset(batch_size): |
| 67 | + while True: |
| 68 | + idx = np.random.choice(len(X), size=batch_size) |
| 69 | + yield X[idx], Y[idx] |
| 70 | + |
| 71 | + |
| 72 | +class Linear(nnx.Module): |
| 73 | + __data__ = ('w', 'b') |
| 74 | + |
| 75 | + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): |
| 76 | + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) |
| 77 | + self.b = nnx.Param(jnp.zeros((dout,))) |
| 78 | + |
| 79 | + def __call__(self, x): |
| 80 | + return x @ self.w[...] + self.b[None] |
| 81 | + |
| 82 | + |
| 83 | +class Count(nnx.Variable[nnx.A]): |
| 84 | + pass |
| 85 | + |
| 86 | + |
| 87 | +class MLP(nnx.Module): |
| 88 | + __data__ = ('count', 'linear1', 'linear2') |
| 89 | + |
| 90 | + def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): |
| 91 | + self.count = Count(jnp.array(0)) |
| 92 | + self.linear1 = Linear(din, dhidden, rngs=rngs) |
| 93 | + self.linear2 = Linear(dhidden, dout, rngs=rngs) |
| 94 | + |
| 95 | + def __call__(self, x): |
| 96 | + self.count[...] += 1 |
| 97 | + return self.linear2(jax.nn.gelu(self.linear1(x)) * 0.5) |
| 98 | + |
| 99 | + |
| 100 | +model = MLP(din=1, dhidden=64, dout=1, rngs=nnx.Rngs(0)) |
| 101 | + |
| 102 | + |
| 103 | +@jax.jit |
| 104 | +def train_step(model, x, y): |
| 105 | + graphdef, params, counts = nnx.split(model, nnx.Param, Count) |
| 106 | + |
| 107 | + def loss_fn(params): |
| 108 | + model = nnx.merge(graphdef, params, counts) |
| 109 | + return jnp.mean((y - model(x)) ** 2) |
| 110 | + |
| 111 | + grads = jax.grad(loss_fn)(freeze(params)) |
| 112 | + |
| 113 | + def sgd(w, g): |
| 114 | + w[...] -= 0.1 * g[...] |
| 115 | + |
| 116 | + jax.tree.map(sgd, params, grads) |
| 117 | + |
| 118 | + |
| 119 | +@jax.jit |
| 120 | +def test_step(model: MLP, x, y): |
| 121 | + return {'loss': jnp.mean((y - model(x)) ** 2)} |
| 122 | + |
| 123 | + |
| 124 | +total_steps = 10_000 |
| 125 | +for step, (x, y) in enumerate(dataset(32)): |
| 126 | + train_step(model, x, y) |
| 127 | + |
| 128 | + if step % 1000 == 0: |
| 129 | + logs = test_step(model, X, Y) |
| 130 | + print(f"step: {step}, loss: {logs['loss']}") |
| 131 | + |
| 132 | + if step >= total_steps - 1: |
| 133 | + break |
| 134 | + |
| 135 | +print('times called:', model.count.value) |
| 136 | + |
| 137 | +y_pred = model(X) |
| 138 | + |
| 139 | +plt.scatter(X, Y, color='blue') |
| 140 | +plt.plot(X, y_pred, color='black') |
| 141 | +plt.show() |
0 commit comments