Skip to content

Commit

Permalink
(☞゚∀゚)☞ grad is not slow!
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Nov 10, 2023
1 parent 24b59f3 commit 8d539f4
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions vbjax/tests/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax
import jax.numpy as np
import vbjax as vb

import pytest

def test_timeavg():
buf, ta_step, ta_sample = vb.make_timeavg((4,))
Expand Down Expand Up @@ -45,7 +45,7 @@ def test_bold():
assert fmri.shape == (n,)


def test_multiple_periods():
def setup_multiple_periods(unroll):

# setup monitors
eeg_gain = vb.randn(64, 32)
Expand All @@ -57,7 +57,8 @@ def test_multiple_periods():
# TODO may be easier with jax_dataclasses
sim = {
'eeg_buf': eeg_buf,
'bold_buf': bold_buf
'bold_buf': bold_buf,
'freq': 0.1,
}

# inner scan steps neural dynamics & monitor states
Expand All @@ -66,32 +67,54 @@ def op1(sim, t):
# insert neural dynamics here
x = jax.random.normal(key_t, shape=(eeg_gain.shape[1],))
# update monitors
sim['eeg_buf'] = eeg_step(sim['eeg_buf'], x)
sim['eeg_buf'] = eeg_step(sim['eeg_buf'], np.sin(x * sim['freq']))
sim['bold_buf'] = bold_step(sim['bold_buf'], np.sin(x) * 0.25 + 1.0)
return sim, x

# next scan samples eeg monitors
def op2(sim, t_):
# sample eeg w/ period of 10*dt
sim, raw = jax.lax.scan(op1, sim, t_ * 10 + np.r_[:10])
sim, raw = jax.lax.scan(op1, sim, t_ * 10 + np.r_[:10],
unroll=10 if unroll else 1)
sim['eeg_buf'], eeg_t = eeg_sample(sim['eeg_buf'])
return sim, (raw, eeg_t)

# outer scan steps from one bold sample to the next
def op3(sim, T):
# run for 5 samples of eeg
sim, (raw, eeg) = jax.lax.scan(op2, sim, T*50 + np.r_[:5])
sim, (raw, eeg) = jax.lax.scan(op2, sim, T*50 + np.r_[:5],
unroll=5 if unroll else 1)
# sample fmri w/ period of 5*10*dt
_, fmri = bold_sample(sim['bold_buf'])
return sim, (raw, eeg, fmri)

return sim, op3

def test_multiple_periods():
sim, op3 = setup_multiple_periods()
ts = np.r_[:10]
sim, (raw, eeg, fmri) = jax.lax.scan(op3, sim, ts)

assert raw.shape == (ts.size, 5, 10, 32)
assert eeg.shape == (ts.size, 5, 64)
assert fmri.shape == (ts.size, 32)

# that covers most use cases
# metrics like fcd can be at done at fmri time scale

@pytest.mark.parametrize('dojit,unroll,grad', [
(True,True,True), (False,False,True),
(True, True, False), (False, False, False),
])
def test_multiple_periods_perf(benchmark, dojit, unroll, grad):
sim, op3 = setup_multiple_periods(unroll)
ts = np.r_[:100]
def run(freq, sim):
sim = sim.copy()
sim['freq'] = freq
sim, (raw, eeg, fmri) = jax.lax.scan(op3, sim, ts,
unroll=10 if unroll else 1)
return np.sum(np.square(eeg))
if grad:
run = jax.grad(run)
assert np.abs(run(0.2,sim)) > 0
if dojit:
run = jax.jit(run)
run(0.2, sim)
benchmark(lambda : run(0.2, sim))

0 comments on commit 8d539f4

Please sign in to comment.