From c7bc8ba47fc9d28b9bf4091fb3c1444dc4bf9538 Mon Sep 17 00:00:00 2001 From: marmaduke woodman Date: Fri, 10 Nov 2023 10:46:50 +0100 Subject: [PATCH] =?UTF-8?q?=E2=8A=82(=E2=97=89=E2=80=BF=E2=97=89)=E3=81=A4?= =?UTF-8?q?=20underp=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vbjax/tests/test_jaxisms.py | 56 +++++++++++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/vbjax/tests/test_jaxisms.py b/vbjax/tests/test_jaxisms.py index c140037..0190721 100644 --- a/vbjax/tests/test_jaxisms.py +++ b/vbjax/tests/test_jaxisms.py @@ -8,8 +8,8 @@ import jax import jax.dlpack import jax.test_util - -import vbjax +import jax.numpy as np +import vbjax as vb def test_dlpack_numpy(): @@ -46,10 +46,10 @@ def test_batched_jax_norm(): "test how to batch norms of derivatives and Jacobians" # this is for neural ode, so make some layers - nn_p, nn_f = vbjax.make_dense_layers(3, [13]) + nn_p, nn_f = vb.make_dense_layers(3, [13]) assert callable(nn_f) - x = vbjax.randn(100, 3, 50) + x = vb.randn(100, 3, 50) x00 = x[0,:,0] f = lambda x: nn_f(nn_p, x) @@ -135,6 +135,7 @@ def loop(sim, ts): 'dt': 0.01, 'coupling': 0.01 } + ts = np.r_[:100] def loss(params, sim, ts): sim = sim.copy() @@ -147,4 +148,49 @@ def loss(params, sim, ts): v, grads = jloss(params, sim, ts) for key, g_key in grads.items(): assert np.abs(g_key) > 0 - \ No newline at end of file + + +def test_loop_dict_vmap(): + + sim = { + 'noise': 0.1, + 'coupling': 0.01, + 'dt': 0.1, + 'x': vb.randn(164), + 'weights': vb.randn(164,164), + 'eeg_gain': vb.randn(64, 164), + 'rng': jax.random.PRNGKey(42) + } + + + def step(sim, t): + x = sim['x'] + z = sim['noise'] * jax.random.normal(sim['rng'], shape=x.shape) + x = x + sim['dt']*(x - x**3/3 + sim['coupling'] * sim['weights']@x) + z + sim['x'] = x + eeg = sim['eeg_gain'] @ x + return sim, eeg + + def loop(sim, ts): + sim, eegs = jax.lax.scan(step, sim, ts) + return eegs + + # to vmap we need specific args to vmap over + # in simple case this is an array like initial conditions, + def loop_x0(x0, sim, ts): + sim = sim.copy() + sim['x'] = x0 + return loop(sim, ts) + + ts = np.r_[:100] + + l1 = jax.jit(jax.vmap(lambda x: loop_x0(x, sim, ts))) + l2 = jax.jit(lambda x: loop_x0(x, sim, ts)) + l3 = jax.jit(jax.vmap(lambda x: loop_x0(x, sim, ts), in_axes=1, out_axes=-1)) + + a1 = vb.randn(32, 164) + a2 = vb.randn(164, 32) + + assert l1(a1).shape == (32, ts.size, 64) + assert l2(a2).shape == (ts.size, 64, 32) + assert l3(a2).shape == (ts.size, 64, 32) \ No newline at end of file