
Commit

Merge pull request #41 from ins-amu/monitor
Add monitors
marmaduke woodman authored Nov 13, 2023
2 parents dc9e385 + c14ee55 commit cdf991f
Showing 11 changed files with 336 additions and 22 deletions.
48 changes: 41 additions & 7 deletions README.md
@@ -31,7 +31,15 @@ for Jax to discover the GPU(s).

## Examples

Here are some examples of simulations which show what you can
do with the library. Because they are implemented atop Jax, it
is easy to take gradients for optimization or MCMC, or do efficient
GPU parallel batching. Or both ಠ_ರೃ
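
For instance (a generic sketch, not code from this repository; `loop`, `x0`, `zs`,
`observed`, `params` and `many_params` are placeholders for whatever simulation
you have built), gradients and batching are just Jax transforms:

```python
import jax

# squared error between a simulated time series and observed data
loss = lambda p: ((loop(x0, zs, p) - observed) ** 2).sum()

grads = jax.grad(loss)(params)                           # gradient w.r.t. parameters
runs = jax.vmap(lambda p: loop(x0, zs, p))(many_params)  # batch over parameter sets
```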

### Simple network

Here's the smallest simulation you might want to do:
an all-to-all connected network with Montbrio-Pazo-Roxin
mass model dynamics,

```python
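# The example body is collapsed in this diff view; the lines below are a
# plausible sketch (assumed helpers: vb.make_sde, vb.mpr_dfun,
# vb.mpr_default_theta, vb.randn); only the final plot call comes from the diff.
import vbjax as vb

def network(x, p):
    c = 0.03 * x[0].sum()          # all-to-all coupling through the firing rates
    return vb.mpr_dfun(x, c, p)    # Montbrio-Pazo-Roxin mass model

_, loop = vb.make_sde(dt=0.01, dfun=network, gfun=0.1)  # stochastic Heun integrator
zs = vb.randn(500, 2, 32)                               # noise: 500 steps, 2 states, 32 nodes
xs = loop(zs[0], zs[1:], vb.mpr_default_theta)          # run the simulation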
vb.plot_states(xs, 'rV', jpg='example1', show=True)
```

While integrators and mass models tend to be the same across publications,
the network model itself varies (regions vs surface, stimulus, etc.), so vbjax lets the
user focus on defining the `network` and then getting time series.

### Simplest neural field

Here's a neural field,
```python
# ...
vb.make_field_gif(xt[::10], 'example2.gif')
```

![](example2.gif)

This example shows how the field forms patterns gradually despite the
noise in the simulation, due to the effect of local connectivity.


### MCMC estimation of neural field activity

For MCMC estimates with NumPyro we define a function to compute
posterior log probability `p(theta | x)`,
```python
def logp(xt=None):
    x0h = numpyro.sample('x0h', dist.Normal(jnp.zeros((nlat, nlon)), 1))
    xth_mu = loop(x0h, ts, k)
    numpyro.sample('xth', dist.Normal(xth_mu, 1), obs=xt)
```
run MCMC w/ NUTS,
```python
mcmc = MCMC(NUTS(logp), num_warmup=500, num_samples=500)
mcmc.run(jax.random.PRNGKey(0), xt=xt)
x0h = mcmc.get_samples()['x0h']
```
check diagnostics like effective sample size, shrinkage and z-score,
```python
ess = numpyro.diagnostics.effective_sample_size(x0h.reshape((1, 500, -1)))
assert ess.min() > 100
shrinkage, zscore = vbjax.shrinkage_zscore(x0, x0h, 1)
assert shrinkage.min() > 0.7
assert zscore.max() < 1.5
```
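Here `shrinkage_zscore` is assumed to follow the usual posterior-calibration
definitions (the `diagnostics` module itself is not shown in this diff); roughly,
```python
prior_sd = 1.0                                             # matches the Normal(0, 1) prior above
shrinkage = 1 - x0h.var(axis=0) / prior_sd**2              # -> 1 when the posterior is much tighter than the prior
zscore = jnp.abs(x0h.mean(axis=0) - x0) / x0h.std(axis=0)  # -> small when the posterior is centred on the truth
```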
Full code is in the [test suite](vbjax/tests/test_field_inference.py); it can
be run with `pytest -m slow`, since it takes about 5 minutes on GPU and 12 minutes on a single CPU core.

### Fitting an autoregressive process

@@ -133,7 +167,7 @@ MCMC algorithms.

## HPC usage

We use this on HPC systems, most easily with container images. Open an issue if it doesn't work.

<details><summary>CSCS Piz Daint</summary>

4 changes: 1 addition & 3 deletions pytest.ini
@@ -4,6 +4,4 @@ markers =
testpaths =
    vbjax/tests
addopts =
    -m 'not slow'
9 changes: 7 additions & 2 deletions vbjax/__init__.py
@@ -1,3 +1,4 @@
print('vbjax ███▒▒▒▒▒▒▒ loading')

# default to setting up many cores
def _use_many_cores():
@@ -7,9 +8,9 @@ def _use_many_cores():
        n = mp.cpu_count()
        value = '--xla_force_host_platform_device_count=%d' % n
        os.environ['XLA_FLAGS'] = value
        print(f'vbjax (ノ☉ヮ⚆)ノ ⌒*:・゚✧ can haz {n} cores')
    else:
        print('vbjax XLA_FLAGS already set\n')

_use_many_cores()

@@ -19,16 +20,19 @@
from .neural_mass import (
    JRState, JRTheta, jr_dfun, jr_default_theta,
    MPRState, MPRTheta, mpr_dfun, mpr_default_theta,
    BOLDTheta, compute_bold_theta, bold_default_theta, bold_dfun
)
from .regmap import make_region_mapping
from .coupling import (
    make_diff_cfun, make_linear_cfun, make_delayed_coupling
)
from .connectome import make_conn_latent_mvnorm
from .sparse import make_spmv, csr_to_jax_bcoo, make_sg_spmv
from .monitor import make_timeavg, make_bold, make_gain, make_offline
from .layers import make_dense_layers
from .diagnostics import shrinkage_zscore
from .embed import embed_neural_flow, embed_polynomial, embed_gradient, embed_autoregress
from .util import to_jax, to_np
from ._version import __version__

# some random setup for convenience
@@ -79,3 +83,4 @@ def animate(i):
    writer = animation.PillowWriter(fps=fps, bitrate=400)
    ani.save(gifname, writer=writer)

print('vbjax ᕕ(ᐛ)ᕗ ready')
66 changes: 66 additions & 0 deletions vbjax/monitor.py
@@ -0,0 +1,66 @@
import jax
import jax.numpy as np
from .loops import heun_step
from .neural_mass import BOLDTheta, bold_dfun


def make_offline(step_fn, sample_fn, *args):
    "Compute monitor samples in an offline or batch fashion."
    def op(mon, x):
        mon = step_fn(mon, x)
        return mon, None
    def offline_sample(mon, xs):
        mon, _ = jax.lax.scan(op, mon, xs)
        mon, samp = sample_fn(mon)
        return mon, samp
    return offline_sample

# NB shape here is the input shape of neural activity

def make_timeavg(shape):
    "Make a time average monitor."
    new = lambda : {'y': np.zeros(shape), 'n': 0}
    def step(buf, x):
        return {'y': buf['y'] + x,
                'n': buf['n'] + 1}
    def sample(buf):
        return new(), buf['y'] / buf['n']
    return new(), step, sample


def compute_sarvas_gain(q, r, o, att, Ds=0, Dc=0) -> np.ndarray:
    # https://gist.github.com/maedoc/add7c3206f81d59105753a04f7c1fcf4
    pass


def make_gain(gain, shape=None):
    "Make a gain-matrix monitor suitable for sEEG, EEG & MEG."
    tavg_shape = gain.shape[:1] + (shape[1:] if shape else ())
    buf, tavg_step, tavg_sample = make_timeavg(tavg_shape)
    step = lambda b, x: tavg_step(b, gain @ x)
    return buf, step, tavg_sample


def make_bold(shape, dt, p: BOLDTheta):
    "Make a BOLD fMRI monitor."
    sfvq = np.ones((4,) + shape)
    sfvq = sfvq.at[0].set(0)
    def step(sfvq, x):
        return heun_step(sfvq, bold_dfun, dt, x, p)
    def sample(buf):
        s, f, v, q = buf
        return buf, p.v0 * (p.k1*(1 - q) + p.k2*(1 - q / v) + p.k3*(1 - v))
    return sfvq, step, sample


def make_fc(shape):
    # welford online cov estimate yields o(1) backprop memory usage
    # https://github.com/maedoc/tvb-fut/blob/master/lib/github.com/maedoc/tvb-fut/stats.fut#L9
    pass

def make_fft(shape, period):
    pass

# TODO sliding window versions of those

# @jax.checkpoint to lower memory usage?
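
A minimal usage sketch for the new monitors (illustration only, not part of the
commit; the activity `xs`, the shapes and the gain matrix are made up):

```python
import vbjax as vb

xs = vb.randn(100, 8)                        # fake neural activity: 100 steps, 8 nodes

buf, step, sample = vb.make_timeavg((8,))    # time-average monitor
run = vb.make_offline(step, sample)          # fold a whole block of samples at once
buf, ta = run(buf, xs)                       # ta: time average of xs, shape (8,)

gain = vb.randn(64, 8)                       # fake sEEG/EEG/MEG gain matrix
gbuf, gstep, gsample = vb.make_gain(gain)    # gain-projected time average
gbuf, eeg = vb.make_offline(gstep, gsample)(gbuf, xs)   # eeg: shape (64,)
```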
45 changes: 45 additions & 0 deletions vbjax/neural_mass.py
@@ -79,3 +79,48 @@ def mpr_dfun(ys, c, p):
        (1 / p.tau) * (V ** 2 + p.eta + p.J * p.tau *
         r + p.I + I_c - (np.pi ** 2) * (r ** 2) * (p.tau ** 2))
    ])


BOLDTheta = collections.namedtuple(
    typename='BOLDTheta',
    field_names='tau_s,tau_f,tau_o,alpha,te,v0,e0,epsilon,nu_0,'
                'r_0,recip_tau_s,recip_tau_f,recip_tau_o,recip_alpha,'
                'recip_e0,k1,k2,k3'
)

def compute_bold_theta(
        tau_s=0.65,
        tau_f=0.41,
        tau_o=0.98,
        alpha=0.32,
        te=0.04,
        v0=4.0,
        e0=0.4,
        epsilon=0.5,
        nu_0=40.3,
        r_0=25.0,
):
    recip_tau_s = 1.0 / tau_s
    recip_tau_f = 1.0 / tau_f
    recip_tau_o = 1.0 / tau_o
    recip_alpha = 1.0 / alpha
    recip_e0 = 1.0 / e0
    k1 = 4.3 * nu_0 * e0 * te
    k2 = epsilon * r_0 * e0 * te
    k3 = 1.0 - epsilon
    return BOLDTheta(**locals())

bold_default_theta = compute_bold_theta()

def bold_dfun(sfvq, x, p: BOLDTheta):
    # Balloon-Windkessel hemodynamics driven by neural activity x:
    # s vasodilatory signal, f inflow, v blood volume, q deoxyhemoglobin content.
    s, f, v, q = sfvq
    ds = x - p.recip_tau_s * s - p.recip_tau_f * (f - 1)
    df = s
    dv = p.recip_tau_o * (f - v ** p.recip_alpha)
    dq = p.recip_tau_o * (f * (1 - (1 - p.e0) ** (1 / f)) * p.recip_e0
                          - v ** p.recip_alpha * (q / v))
    return np.array([ds, df, dv, dq])


# TODO other models
# TODO codim3 https://gist.github.com/maedoc/01cea5cad9c833c56349392ee7d9b627
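
As a small illustration (not part of the diff), the derived constants can be
recomputed for different hemodynamic parameters:

```python
p = compute_bold_theta(tau_s=0.8, e0=0.5)     # override two constants, rederive the rest
assert abs(p.recip_tau_s - 1 / 0.8) < 1e-12
assert abs(p.k3 - (1 - p.epsilon)) < 1e-12    # k3 = 1 - epsilon
```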
3 changes: 1 addition & 2 deletions vbjax/shtlc.py
@@ -5,8 +5,7 @@
try:
    import shtns
except ImportError:
    print('vbjax ò_ô shtns is not available')


# Grid functions
11 changes: 4 additions & 7 deletions vbjax/sparse.py
@@ -1,11 +1,8 @@
import numpy as np
import jax.dlpack
import scipy.sparse
import jax
import jax.experimental.sparse as jsp


_to_np = lambda x: np.from_dlpack(x)
_to_jax = lambda x: jax.dlpack.from_dlpack(x.__dlpack__())
import vbjax as vb


def make_spmv(A, is_symmetric=False, use_scipy=False):
@@ -30,8 +27,8 @@ def make_spmv(A, is_symmetric=False, use_scipy=False):
"""
AT = A.T.copy()
@jax.custom_vjp
def matvec(x): return _to_jax(A @ _to_np(x))
def matvec_tr(x): return _to_jax(AT @ _to_np(x))
def matvec(x): return vb.to_jax(A @ vb.to_np(x))
def matvec_tr(x): return vb.to_jax(AT @ vb.to_np(x))
def matvec_fwd(x): return matvec(x), None
def matvec_bwd(res, g): return matvec(g) if is_symmetric else matvec_tr(g),
matvec.defvjp(matvec_fwd, matvec_bwd)
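
A usage sketch for `make_spmv` (illustrative; the random matrix and sizes are made up):

```python
import scipy.sparse
import jax.numpy as jnp
import vbjax as vb

A = scipy.sparse.random(100, 100, density=0.05, format='csr', dtype='float32')
spmv = vb.make_spmv(A)       # sparse matrix-vector product with a custom VJP (A.T is used in reverse)
y = spmv(jnp.ones(100))      # forward pass computed by scipy, returned as a jax array
```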
3 changes: 3 additions & 0 deletions vbjax/tests/test_loops.py
@@ -40,3 +40,6 @@ def dfun(xt, x, t, p):
    return -xt[t - 5]  # drift depends on the delayed state, 5 steps in the past
_, sdde = vb.make_sdde(1.0, 5, dfun, 0.01)
sdde(vb.randn(20)+10, None)


# TODO theta method? https://gist.github.com/maedoc/c47acb9d346e31017e05324ffc4582c1