From 8d539f48ecdd5930eac55f1871a91a92df189f0e Mon Sep 17 00:00:00 2001 From: marmaduke woodman Date: Fri, 10 Nov 2023 18:15:27 +0100 Subject: [PATCH] =?UTF-8?q?(=E2=98=9E=EF=BE=9F=E2=88=80=EF=BE=9F)=E2=98=9E?= =?UTF-8?q?=20grad=20is=20not=20slow!?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vbjax/tests/test_monitor.py | 43 ++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/vbjax/tests/test_monitor.py b/vbjax/tests/test_monitor.py index 4fa7a87..16eafd0 100644 --- a/vbjax/tests/test_monitor.py +++ b/vbjax/tests/test_monitor.py @@ -2,7 +2,7 @@ import jax import jax.numpy as np import vbjax as vb - +import pytest def test_timeavg(): buf, ta_step, ta_sample = vb.make_timeavg((4,)) @@ -45,7 +45,7 @@ def test_bold(): assert fmri.shape == (n,) -def test_multiple_periods(): +def setup_multiple_periods(unroll): # setup monitors eeg_gain = vb.randn(64, 32) @@ -57,7 +57,8 @@ def test_multiple_periods(): # TODO may be easier with jax_dataclasses sim = { 'eeg_buf': eeg_buf, - 'bold_buf': bold_buf + 'bold_buf': bold_buf, + 'freq': 0.1, } # inner scan steps neural dynamics & monitor states @@ -66,32 +67,54 @@ def op1(sim, t): # insert neural dynamics here x = jax.random.normal(key_t, shape=(eeg_gain.shape[1],)) # update monitors - sim['eeg_buf'] = eeg_step(sim['eeg_buf'], x) + sim['eeg_buf'] = eeg_step(sim['eeg_buf'], np.sin(x * sim['freq'])) sim['bold_buf'] = bold_step(sim['bold_buf'], np.sin(x) * 0.25 + 1.0) return sim, x # next scan samples eeg monitors def op2(sim, t_): # sample eeg w/ period of 10*dt - sim, raw = jax.lax.scan(op1, sim, t_ * 10 + np.r_[:10]) + sim, raw = jax.lax.scan(op1, sim, t_ * 10 + np.r_[:10], + unroll=10 if unroll else 1) sim['eeg_buf'], eeg_t = eeg_sample(sim['eeg_buf']) return sim, (raw, eeg_t) # outer scan steps from one bold sample to the next def op3(sim, T): # run for 5 samples of eeg - sim, (raw, eeg) = jax.lax.scan(op2, sim, T*50 + np.r_[:5]) + sim, (raw, eeg) = jax.lax.scan(op2, sim, T*50 + np.r_[:5], + unroll=5 if unroll else 1) # sample fmri w/ period of 5*10*dt _, fmri = bold_sample(sim['bold_buf']) return sim, (raw, eeg, fmri) + return sim, op3 + +def test_multiple_periods(): + sim, op3 = setup_multiple_periods() ts = np.r_[:10] sim, (raw, eeg, fmri) = jax.lax.scan(op3, sim, ts) - assert raw.shape == (ts.size, 5, 10, 32) assert eeg.shape == (ts.size, 5, 64) assert fmri.shape == (ts.size, 32) - # that covers most use cases - # metrics like fcd can be at done at fmri time scale - +@pytest.mark.parametrize('dojit,unroll,grad', [ + (True,True,True), (False,False,True), + (True, True, False), (False, False, False), +]) +def test_multiple_periods_perf(benchmark, dojit, unroll, grad): + sim, op3 = setup_multiple_periods(unroll) + ts = np.r_[:100] + def run(freq, sim): + sim = sim.copy() + sim['freq'] = freq + sim, (raw, eeg, fmri) = jax.lax.scan(op3, sim, ts, + unroll=10 if unroll else 1) + return np.sum(np.square(eeg)) + if grad: + run = jax.grad(run) + assert np.abs(run(0.2,sim)) > 0 + if dojit: + run = jax.jit(run) + run(0.2, sim) + benchmark(lambda : run(0.2, sim))