Skip to content

Commit

Permalink
⊂(◉‿◉)つ underp test
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Nov 10, 2023
1 parent 0aef8a7 commit c7bc8ba
Showing 1 changed file with 51 additions and 5 deletions.
56 changes: 51 additions & 5 deletions vbjax/tests/test_jaxisms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import jax
import jax.dlpack
import jax.test_util

import vbjax
import jax.numpy as np
import vbjax as vb


def test_dlpack_numpy():
Expand Down Expand Up @@ -46,10 +46,10 @@ def test_batched_jax_norm():
"test how to batch norms of derivatives and Jacobians"

# this is for neural ode, so make some layers
nn_p, nn_f = vbjax.make_dense_layers(3, [13])
nn_p, nn_f = vb.make_dense_layers(3, [13])
assert callable(nn_f)

x = vbjax.randn(100, 3, 50)
x = vb.randn(100, 3, 50)
x00 = x[0,:,0]

f = lambda x: nn_f(nn_p, x)
Expand Down Expand Up @@ -135,6 +135,7 @@ def loop(sim, ts):
'dt': 0.01,
'coupling': 0.01
}
ts = np.r_[:100]

def loss(params, sim, ts):
sim = sim.copy()
Expand All @@ -147,4 +148,49 @@ def loss(params, sim, ts):
v, grads = jloss(params, sim, ts)
for key, g_key in grads.items():
assert np.abs(g_key) > 0



def test_loop_dict_vmap():

sim = {
'noise': 0.1,
'coupling': 0.01,
'dt': 0.1,
'x': vb.randn(164),
'weights': vb.randn(164,164),
'eeg_gain': vb.randn(64, 164),
'rng': jax.random.PRNGKey(42)
}


def step(sim, t):
x = sim['x']
z = sim['noise'] * jax.random.normal(sim['rng'], shape=x.shape)
x = x + sim['dt']*(x - x**3/3 + sim['coupling'] * sim['weights']@x) + z
sim['x'] = x
eeg = sim['eeg_gain'] @ x
return sim, eeg

def loop(sim, ts):
sim, eegs = jax.lax.scan(step, sim, ts)
return eegs

# to vmap we need specific args to vmap over
# in simple case this is an array like initial conditions,
def loop_x0(x0, sim, ts):
sim = sim.copy()
sim['x'] = x0
return loop(sim, ts)

ts = np.r_[:100]

l1 = jax.jit(jax.vmap(lambda x: loop_x0(x, sim, ts)))
l2 = jax.jit(lambda x: loop_x0(x, sim, ts))
l3 = jax.jit(jax.vmap(lambda x: loop_x0(x, sim, ts), in_axes=1, out_axes=-1))

a1 = vb.randn(32, 164)
a2 = vb.randn(164, 32)

assert l1(a1).shape == (32, ts.size, 64)
assert l2(a2).shape == (ts.size, 64, 32)
assert l3(a2).shape == (ts.size, 64, 32)

0 comments on commit c7bc8ba

Please sign in to comment.