|
| 1 | +# Copyright 2024 The Flax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import os |
| 16 | +from typing import Any |
| 17 | + |
| 18 | +from flax.linen.dtypes import promote_dtype |
| 19 | + |
| 20 | +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' |
| 21 | + |
| 22 | +import jax |
| 23 | +import jax.numpy as jnp |
| 24 | +from absl.testing import absltest |
| 25 | + |
| 26 | +from flax import linen as nn |
| 27 | +from flax import nnx |
| 28 | +from flax.nnx import bridge |
| 29 | +from flax.nnx.bridge.module import MODULE_CONTEXT |
| 30 | + |
| 31 | + |
| 32 | +class TestBridgeModule(absltest.TestCase): |
| 33 | + def test_update(self): |
| 34 | + class Foo(bridge.Module): |
| 35 | + a: int |
| 36 | + |
| 37 | + foo = Foo(1) |
| 38 | + state = {'b': {'c': nnx.Param(jnp.array(2))}} |
| 39 | + nnx.update(foo, state) |
| 40 | + |
| 41 | + def test_module_stack(self): |
| 42 | + """Test that apply set the module stack correctly.""" |
| 43 | + test = self |
| 44 | + |
| 45 | + class Foo(bridge.Module): |
| 46 | + def setup(self): |
| 47 | + current_ctx = MODULE_CONTEXT.module_stack[-1] |
| 48 | + test.assertIs(current_ctx.module, self) |
| 49 | + test.assertFalse(current_ctx.in_compact) |
| 50 | + |
| 51 | + def __call__(self): |
| 52 | + current_ctx = MODULE_CONTEXT.module_stack[-1] |
| 53 | + test.assertIs(current_ctx.module, self) |
| 54 | + test.assertFalse(current_ctx.in_compact) |
| 55 | + |
| 56 | + foo = Foo() |
| 57 | + foo.apply({}) |
| 58 | + |
| 59 | + def test_compact_basic(self): |
| 60 | + test = self |
| 61 | + class Linear(bridge.Module): |
| 62 | + dout: int |
| 63 | + |
| 64 | + @bridge.compact |
| 65 | + def __call__(self, x): |
| 66 | + w = self.param( |
| 67 | + 'w', nnx.initializers.uniform(), (x.shape[-1], self.dout) |
| 68 | + ) |
| 69 | + b = self.param('b', nn.initializers.zeros_init(), (self.dout,)) |
| 70 | + return x @ w + b[None] |
| 71 | + |
| 72 | + class Foo(bridge.Module): |
| 73 | + dout: int |
| 74 | + |
| 75 | + @bridge.compact |
| 76 | + def __call__(self, x): |
| 77 | + din = x.shape[-1] |
| 78 | + self.linear = Linear(self.dout) |
| 79 | + x = self.linear(x) |
| 80 | + |
| 81 | + # NNX |
| 82 | + graphdef, state = nnx.split(self) |
| 83 | + test.assertIn('Linear_0', state) |
| 84 | + test.assertIn('w', state['Linear_0']) |
| 85 | + test.assertIn('b', state['Linear_0']) |
| 86 | + |
| 87 | + return x |
| 88 | + |
| 89 | + foo = Foo(5) |
| 90 | + x = jnp.ones((3, 2)) |
| 91 | + |
| 92 | + self.assertIsInstance(foo, nnx.Module) |
| 93 | + |
| 94 | + variables = foo.init(0, x) |
| 95 | + params = variables['params'] |
| 96 | + |
| 97 | + self.assertIn('Linear_0', params) |
| 98 | + self.assertIn('w', params['Linear_0']) |
| 99 | + self.assertIn('b', params['Linear_0']) |
| 100 | + self.assertEqual(params['Linear_0']['w'].shape, (2, 5)) |
| 101 | + self.assertEqual(params['Linear_0']['b'].shape, (5,)) |
| 102 | + |
| 103 | + y: jax.Array = foo.apply(variables, x) |
| 104 | + |
| 105 | + self.assertEqual(y.shape, (3, 5)) |
| 106 | + |
| 107 | + def test_mutable_state(self): |
| 108 | + class FooLinen(nn.Module): |
| 109 | + @nn.compact |
| 110 | + def __call__(self): |
| 111 | + count = self.variable( |
| 112 | + 'counts', 'count', lambda: jnp.zeros((), jnp.int32) |
| 113 | + ) |
| 114 | + count.value += 1 |
| 115 | + |
| 116 | + model_linen = FooLinen() |
| 117 | + initial_vars_linen = model_linen.init({}) |
| 118 | + _, vars_linen = model_linen.apply(initial_vars_linen, mutable='counts') |
| 119 | + |
| 120 | + class FooNNX(bridge.Module): |
| 121 | + @bridge.compact |
| 122 | + def __call__(self): |
| 123 | + count = self.variable( |
| 124 | + 'counts', 'count', lambda: jnp.zeros((), jnp.int32) |
| 125 | + ) |
| 126 | + count.value += 1 |
| 127 | + |
| 128 | + model_nnx = FooNNX() |
| 129 | + |
| 130 | + initial_vars_nnx = model_nnx.init({}) |
| 131 | + _, vars_nnx = model_nnx.apply(initial_vars_nnx, mutable='counts') |
| 132 | + |
| 133 | + self.assertEqual( |
| 134 | + initial_vars_linen['counts']['count'], initial_vars_nnx['counts']['count'] |
| 135 | + ) |
| 136 | + self.assertEqual(vars_linen['counts']['count'], vars_nnx['counts']['count']) |
| 137 | + |
| 138 | + def test_compact_parent_none(self): |
| 139 | + class Foo(bridge.Module): |
| 140 | + pass |
| 141 | + |
| 142 | + class Bar(bridge.Module): |
| 143 | + @bridge.compact |
| 144 | + def __call__(self): |
| 145 | + return Foo().scope |
| 146 | + |
| 147 | + bar = Bar() |
| 148 | + scope = bar.apply({}, rngs=1) |
| 149 | + self.assertIsNone(bar.scope) |
| 150 | + |
| 151 | + self.assertEqual(scope.rngs.default.key.value, jax.random.key(1)) |
| 152 | + self.assertEqual(scope.rngs.default.count.value, 0) |
| 153 | + |
| 154 | + class Baz(bridge.Module): |
| 155 | + @bridge.compact |
| 156 | + def __call__(self): |
| 157 | + return Foo(parent=None).scope |
| 158 | + |
| 159 | + baz = Baz() |
| 160 | + scope = baz.apply({}, rngs=1) |
| 161 | + self.assertIsNone(scope) |
| 162 | + |
| 163 | + def test_name(self): |
| 164 | + class Foo(bridge.Module): |
| 165 | + dout: int |
| 166 | + |
| 167 | + def __call__(self, x): |
| 168 | + w = self.param( |
| 169 | + 'w', nnx.initializers.uniform(), (x.shape[-1], self.dout) |
| 170 | + ) |
| 171 | + return x @ w |
| 172 | + |
| 173 | + class Bar(bridge.Module): |
| 174 | + @bridge.compact |
| 175 | + def __call__(self, x): |
| 176 | + return Foo(5, name='xyz')(x) |
| 177 | + |
| 178 | + bar = Bar() |
| 179 | + x = jnp.ones((1, 2)) |
| 180 | + y, variables = bar.init_with_output(0, x) |
| 181 | + |
| 182 | + self.assertIn('xyz', variables['params']) |
| 183 | + self.assertEqual(variables['params']['xyz']['w'].shape, (2, 5)) |
| 184 | + self.assertEqual(y.shape, (1, 5)) |
| 185 | + |
| 186 | + y = bar.apply(variables, x) |
| 187 | + self.assertEqual(y.shape, (1, 5)) |
| 188 | + |
| 189 | + with self.assertRaises(ValueError): |
| 190 | + class SetupBar(bridge.Module): |
| 191 | + def setup(self): |
| 192 | + self.xyz = Foo(5, name='xyz') |
| 193 | + def __call__(self, x): |
| 194 | + return self.xyz(x) |
| 195 | + SetupBar().init(0, x) |
| 196 | + |
| 197 | + def test_dense_port(self): |
| 198 | + class Dense(bridge.Module): |
| 199 | + features: int |
| 200 | + use_bias: bool = True |
| 201 | + dtype: Any = None |
| 202 | + param_dtype: Any = jnp.float32 |
| 203 | + precision: Any = None |
| 204 | + kernel_init: Any = nnx.initializers.lecun_normal() |
| 205 | + bias_init: Any = nnx.initializers.zeros_init() |
| 206 | + # Deprecated. Will be removed. |
| 207 | + dot_general: Any | None = None |
| 208 | + dot_general_cls: Any = None |
| 209 | + |
| 210 | + @bridge.compact |
| 211 | + def __call__(self, inputs: jax.Array) -> jax.Array: |
| 212 | + kernel = self.param( |
| 213 | + 'kernel', |
| 214 | + self.kernel_init, |
| 215 | + (jnp.shape(inputs)[-1], self.features), |
| 216 | + self.param_dtype, |
| 217 | + ) |
| 218 | + if self.use_bias: |
| 219 | + bias = self.param( |
| 220 | + 'bias', self.bias_init, (self.features,), self.param_dtype |
| 221 | + ) |
| 222 | + else: |
| 223 | + bias = None |
| 224 | + inputs, kernel, bias = promote_dtype( |
| 225 | + inputs, kernel, bias, dtype=self.dtype |
| 226 | + ) |
| 227 | + |
| 228 | + if self.dot_general_cls is not None: |
| 229 | + dot_general = self.dot_general_cls() |
| 230 | + elif self.dot_general is not None: |
| 231 | + dot_general = self.dot_general |
| 232 | + else: |
| 233 | + dot_general = jax.lax.dot_general |
| 234 | + y = dot_general( |
| 235 | + inputs, |
| 236 | + kernel, |
| 237 | + (((inputs.ndim - 1,), (0,)), ((), ())), |
| 238 | + precision=self.precision, |
| 239 | + ) |
| 240 | + if bias is not None: |
| 241 | + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) |
| 242 | + return y |
| 243 | + |
| 244 | + m = Dense(3) |
| 245 | + x = jnp.ones((1, 10, 2)) |
| 246 | + y, variables = m.init_with_output(0, x) |
| 247 | + |
| 248 | + self.assertEqual(y.shape, (1, 10, 3)) |
| 249 | + self.assertEqual(variables['params']['kernel'].shape, (2, 3)) |
| 250 | + self.assertEqual(variables['params']['bias'].shape, (3,)) |
| 251 | + |
| 252 | + y = m.apply(variables, x) |
| 253 | + |
| 254 | + self.assertEqual(y.shape, (1, 10, 3)) |
| 255 | + self.assertEqual(variables['params']['kernel'].shape, (2, 3)) |
| 256 | + self.assertEqual(variables['params']['bias'].shape, (3,)) |
| 257 | + |
| 258 | + @jax.jit |
| 259 | + def train_step(params, x, y): |
| 260 | + def loss_fn(params): |
| 261 | + y_pred = m.apply({'params': params}, x) |
| 262 | + return jnp.mean((y - y_pred) ** 2) |
| 263 | + |
| 264 | + grads = jax.grad(loss_fn)(params) |
| 265 | + |
| 266 | + params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads) |
| 267 | + |
| 268 | + return params |
| 269 | + |
| 270 | + params = variables['params'] |
| 271 | + x = jnp.ones((1, 10, 2)) |
| 272 | + y = jnp.ones((1, 10, 3)) |
| 273 | + |
| 274 | + params = train_step(params, x, y) |
| 275 | + |
| 276 | + def test_metadata(self): |
| 277 | + class Linear(bridge.Module): |
| 278 | + dout: int |
| 279 | + |
| 280 | + @bridge.compact |
| 281 | + def __call__(self, x): |
| 282 | + w = self.param( |
| 283 | + 'w', bridge.with_partitioning(nnx.initializers.uniform(), ('in', 'out')), |
| 284 | + (x.shape[-1], self.dout) |
| 285 | + ) |
| 286 | + b = self.param('b', nnx.initializers.zeros_init(), (self.dout,)) |
| 287 | + return x @ w + b[None] |
| 288 | + |
| 289 | + foo = Linear(5) |
| 290 | + x = jnp.ones((3, 2)) |
| 291 | + |
| 292 | + variables = foo.init(0, x) |
| 293 | + params = variables['params'] |
| 294 | + self.assertIsInstance(params['w'], nn.Partitioned) |
| 295 | + self.assertEqual(params['w'].value.shape, (2, 5)) |
| 296 | + self.assertEqual(params['w'].names, ('in', 'out')) |
| 297 | + self.assertEqual(nn.get_partition_spec(variables)['params']['w'], |
| 298 | + jax.sharding.PartitionSpec('in', 'out')) |
| 299 | + self.assertIsInstance(params['b'], jax.Array) |
| 300 | + self.assertEqual(params['b'].shape, (5,)) |
| 301 | + |
| 302 | + y: jax.Array = foo.apply(variables, x) |
| 303 | + self.assertEqual(y.shape, (3, 5)) |
| 304 | + |
| 305 | + |
| 306 | +if __name__ == '__main__': |
| 307 | + absltest.main() |
| 308 | + |
0 commit comments