Skip to content

Commit 6af8fcb

Browse files
IvyZXFlax Authors
authored and
Flax Authors
committed
Refactor bridge.Module tests from wrappers_test.py to another file.
PiperOrigin-RevId: 731495913
1 parent 7974b10 commit 6af8fcb

File tree

2 files changed

+308
-277
lines changed

2 files changed

+308
-277
lines changed

tests/nnx/bridge/module_test.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from typing import Any
17+
18+
from flax.linen.dtypes import promote_dtype
19+
20+
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
21+
22+
import jax
23+
import jax.numpy as jnp
24+
from absl.testing import absltest
25+
26+
from flax import linen as nn
27+
from flax import nnx
28+
from flax.nnx import bridge
29+
from flax.nnx.bridge.module import MODULE_CONTEXT
30+
31+
32+
class TestBridgeModule(absltest.TestCase):
33+
def test_update(self):
34+
class Foo(bridge.Module):
35+
a: int
36+
37+
foo = Foo(1)
38+
state = {'b': {'c': nnx.Param(jnp.array(2))}}
39+
nnx.update(foo, state)
40+
41+
def test_module_stack(self):
42+
"""Test that apply set the module stack correctly."""
43+
test = self
44+
45+
class Foo(bridge.Module):
46+
def setup(self):
47+
current_ctx = MODULE_CONTEXT.module_stack[-1]
48+
test.assertIs(current_ctx.module, self)
49+
test.assertFalse(current_ctx.in_compact)
50+
51+
def __call__(self):
52+
current_ctx = MODULE_CONTEXT.module_stack[-1]
53+
test.assertIs(current_ctx.module, self)
54+
test.assertFalse(current_ctx.in_compact)
55+
56+
foo = Foo()
57+
foo.apply({})
58+
59+
def test_compact_basic(self):
60+
test = self
61+
class Linear(bridge.Module):
62+
dout: int
63+
64+
@bridge.compact
65+
def __call__(self, x):
66+
w = self.param(
67+
'w', nnx.initializers.uniform(), (x.shape[-1], self.dout)
68+
)
69+
b = self.param('b', nn.initializers.zeros_init(), (self.dout,))
70+
return x @ w + b[None]
71+
72+
class Foo(bridge.Module):
73+
dout: int
74+
75+
@bridge.compact
76+
def __call__(self, x):
77+
din = x.shape[-1]
78+
self.linear = Linear(self.dout)
79+
x = self.linear(x)
80+
81+
# NNX
82+
graphdef, state = nnx.split(self)
83+
test.assertIn('Linear_0', state)
84+
test.assertIn('w', state['Linear_0'])
85+
test.assertIn('b', state['Linear_0'])
86+
87+
return x
88+
89+
foo = Foo(5)
90+
x = jnp.ones((3, 2))
91+
92+
self.assertIsInstance(foo, nnx.Module)
93+
94+
variables = foo.init(0, x)
95+
params = variables['params']
96+
97+
self.assertIn('Linear_0', params)
98+
self.assertIn('w', params['Linear_0'])
99+
self.assertIn('b', params['Linear_0'])
100+
self.assertEqual(params['Linear_0']['w'].shape, (2, 5))
101+
self.assertEqual(params['Linear_0']['b'].shape, (5,))
102+
103+
y: jax.Array = foo.apply(variables, x)
104+
105+
self.assertEqual(y.shape, (3, 5))
106+
107+
def test_mutable_state(self):
108+
class FooLinen(nn.Module):
109+
@nn.compact
110+
def __call__(self):
111+
count = self.variable(
112+
'counts', 'count', lambda: jnp.zeros((), jnp.int32)
113+
)
114+
count.value += 1
115+
116+
model_linen = FooLinen()
117+
initial_vars_linen = model_linen.init({})
118+
_, vars_linen = model_linen.apply(initial_vars_linen, mutable='counts')
119+
120+
class FooNNX(bridge.Module):
121+
@bridge.compact
122+
def __call__(self):
123+
count = self.variable(
124+
'counts', 'count', lambda: jnp.zeros((), jnp.int32)
125+
)
126+
count.value += 1
127+
128+
model_nnx = FooNNX()
129+
130+
initial_vars_nnx = model_nnx.init({})
131+
_, vars_nnx = model_nnx.apply(initial_vars_nnx, mutable='counts')
132+
133+
self.assertEqual(
134+
initial_vars_linen['counts']['count'], initial_vars_nnx['counts']['count']
135+
)
136+
self.assertEqual(vars_linen['counts']['count'], vars_nnx['counts']['count'])
137+
138+
def test_compact_parent_none(self):
139+
class Foo(bridge.Module):
140+
pass
141+
142+
class Bar(bridge.Module):
143+
@bridge.compact
144+
def __call__(self):
145+
return Foo().scope
146+
147+
bar = Bar()
148+
scope = bar.apply({}, rngs=1)
149+
self.assertIsNone(bar.scope)
150+
151+
self.assertEqual(scope.rngs.default.key.value, jax.random.key(1))
152+
self.assertEqual(scope.rngs.default.count.value, 0)
153+
154+
class Baz(bridge.Module):
155+
@bridge.compact
156+
def __call__(self):
157+
return Foo(parent=None).scope
158+
159+
baz = Baz()
160+
scope = baz.apply({}, rngs=1)
161+
self.assertIsNone(scope)
162+
163+
def test_name(self):
164+
class Foo(bridge.Module):
165+
dout: int
166+
167+
def __call__(self, x):
168+
w = self.param(
169+
'w', nnx.initializers.uniform(), (x.shape[-1], self.dout)
170+
)
171+
return x @ w
172+
173+
class Bar(bridge.Module):
174+
@bridge.compact
175+
def __call__(self, x):
176+
return Foo(5, name='xyz')(x)
177+
178+
bar = Bar()
179+
x = jnp.ones((1, 2))
180+
y, variables = bar.init_with_output(0, x)
181+
182+
self.assertIn('xyz', variables['params'])
183+
self.assertEqual(variables['params']['xyz']['w'].shape, (2, 5))
184+
self.assertEqual(y.shape, (1, 5))
185+
186+
y = bar.apply(variables, x)
187+
self.assertEqual(y.shape, (1, 5))
188+
189+
with self.assertRaises(ValueError):
190+
class SetupBar(bridge.Module):
191+
def setup(self):
192+
self.xyz = Foo(5, name='xyz')
193+
def __call__(self, x):
194+
return self.xyz(x)
195+
SetupBar().init(0, x)
196+
197+
def test_dense_port(self):
198+
class Dense(bridge.Module):
199+
features: int
200+
use_bias: bool = True
201+
dtype: Any = None
202+
param_dtype: Any = jnp.float32
203+
precision: Any = None
204+
kernel_init: Any = nnx.initializers.lecun_normal()
205+
bias_init: Any = nnx.initializers.zeros_init()
206+
# Deprecated. Will be removed.
207+
dot_general: Any | None = None
208+
dot_general_cls: Any = None
209+
210+
@bridge.compact
211+
def __call__(self, inputs: jax.Array) -> jax.Array:
212+
kernel = self.param(
213+
'kernel',
214+
self.kernel_init,
215+
(jnp.shape(inputs)[-1], self.features),
216+
self.param_dtype,
217+
)
218+
if self.use_bias:
219+
bias = self.param(
220+
'bias', self.bias_init, (self.features,), self.param_dtype
221+
)
222+
else:
223+
bias = None
224+
inputs, kernel, bias = promote_dtype(
225+
inputs, kernel, bias, dtype=self.dtype
226+
)
227+
228+
if self.dot_general_cls is not None:
229+
dot_general = self.dot_general_cls()
230+
elif self.dot_general is not None:
231+
dot_general = self.dot_general
232+
else:
233+
dot_general = jax.lax.dot_general
234+
y = dot_general(
235+
inputs,
236+
kernel,
237+
(((inputs.ndim - 1,), (0,)), ((), ())),
238+
precision=self.precision,
239+
)
240+
if bias is not None:
241+
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
242+
return y
243+
244+
m = Dense(3)
245+
x = jnp.ones((1, 10, 2))
246+
y, variables = m.init_with_output(0, x)
247+
248+
self.assertEqual(y.shape, (1, 10, 3))
249+
self.assertEqual(variables['params']['kernel'].shape, (2, 3))
250+
self.assertEqual(variables['params']['bias'].shape, (3,))
251+
252+
y = m.apply(variables, x)
253+
254+
self.assertEqual(y.shape, (1, 10, 3))
255+
self.assertEqual(variables['params']['kernel'].shape, (2, 3))
256+
self.assertEqual(variables['params']['bias'].shape, (3,))
257+
258+
@jax.jit
259+
def train_step(params, x, y):
260+
def loss_fn(params):
261+
y_pred = m.apply({'params': params}, x)
262+
return jnp.mean((y - y_pred) ** 2)
263+
264+
grads = jax.grad(loss_fn)(params)
265+
266+
params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
267+
268+
return params
269+
270+
params = variables['params']
271+
x = jnp.ones((1, 10, 2))
272+
y = jnp.ones((1, 10, 3))
273+
274+
params = train_step(params, x, y)
275+
276+
def test_metadata(self):
277+
class Linear(bridge.Module):
278+
dout: int
279+
280+
@bridge.compact
281+
def __call__(self, x):
282+
w = self.param(
283+
'w', bridge.with_partitioning(nnx.initializers.uniform(), ('in', 'out')),
284+
(x.shape[-1], self.dout)
285+
)
286+
b = self.param('b', nnx.initializers.zeros_init(), (self.dout,))
287+
return x @ w + b[None]
288+
289+
foo = Linear(5)
290+
x = jnp.ones((3, 2))
291+
292+
variables = foo.init(0, x)
293+
params = variables['params']
294+
self.assertIsInstance(params['w'], nn.Partitioned)
295+
self.assertEqual(params['w'].value.shape, (2, 5))
296+
self.assertEqual(params['w'].names, ('in', 'out'))
297+
self.assertEqual(nn.get_partition_spec(variables)['params']['w'],
298+
jax.sharding.PartitionSpec('in', 'out'))
299+
self.assertIsInstance(params['b'], jax.Array)
300+
self.assertEqual(params['b'].shape, (5,))
301+
302+
y: jax.Array = foo.apply(variables, x)
303+
self.assertEqual(y.shape, (3, 5))
304+
305+
306+
if __name__ == '__main__':
307+
absltest.main()
308+

0 commit comments

Comments
 (0)