-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathjr-psd-fit-evo.py
133 lines (112 loc) · 4.24 KB
/
jr-psd-fit-evo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
The goal of this example is to create a model of power spectral densities
generated by the Jansen-Rit model, and to estimate parameters corresponding
to different PSDs.
- sampling rate 500 Hz, 4 s windows started every 1 s, Welch PSD, log scale, 60 s of data
"""
import time
import numpy as np
import pylab as pl
import vbjax as vb
import jax
import jax.numpy as jp
import jax.example_libraries.optimizers as jopt
import tqdm
# First model: one isolated Jansen-Rit node.
def model(state, parameters):
    """Drift function handed to ``vb.make_sde``.

    Evaluates ``vb.jr_dfun`` on ``state`` with ``parameters`` and a constant
    0 as the second argument (presumably the coupling input, i.e. the node
    receives no input from other nodes).
    """
    coupling = 0
    return vb.jr_dfun(state, coupling, parameters)
# Simulation + Welch PSD helpers.
def _welch_psd(signal, win_size, step):
    """Return the Welch power spectrum of a 1-D ``signal`` (all FFT bins).

    Hann-windowed segments of ``win_size`` samples, starting every ``step``
    samples, are Fourier transformed; their power |X|^2 is averaged.

    Raises:
        ValueError: if ``signal`` is shorter than one window.
    """
    # Count every full window, including the last one ending at or before
    # the end of the signal.
    n_windows = (signal.shape[0] - win_size) // step + 1
    if n_windows < 1:
        raise ValueError(
            f'signal of length {signal.shape[0]} is shorter than one window '
            f'({win_size} samples)')
    windows = jp.stack([signal[i * step:i * step + win_size]
                        for i in range(n_windows)])
    windows = windows * jp.hanning(win_size)
    spectra = jp.fft.fft(windows, axis=1)  # FFT over time, per window
    return jp.mean(jp.square(jp.abs(spectra)), axis=0)  # average power


def run_sim_psd(parameters, rng_key):
    """Simulate a single Jansen-Rit node for 60 s and return its Welch PSD.

    Args:
        parameters: 8 values ``(A, B, a, b, v0, r, J, lsig)``. The first seven
            override the same-named fields of ``vb.jr_default_theta``.
            ``exp(lsig)`` is passed to ``vb.make_sde`` as ``gfun`` (the noise
            scale).
        rng_key: JAX PRNG key used to draw the integration noise.

    Returns:
        ``(ftfreq, windows_psd)``: the FFT bin frequencies in Hz (``fftfreq``
        order, 0 Hz first) and the window-averaged power in each bin, both
        of length ``win_size`` (2000).
    """
    A, B, a, b, v0, r, J, lsig = parameters
    dt = 2.0  # ms == 500 Hz sampling frequency
    ntime = int(60e3 / dt)  # 60 s of data
    initial_state = jp.ones((6, 1))
    _, loop = vb.make_sde(dt=dt, dfun=model, gfun=jp.exp(lsig))
    noise = vb.randn(ntime, *initial_state.shape, key=rng_key)
    # Forward every JR parameter the search perturbs (including v0, r, J),
    # so each of them actually influences the simulation.
    theta = vb.jr_default_theta._replace(
        A=A, B=B, a=a, b=b, v0=v0, r=r, J=J)
    states = loop(initial_state, noise, theta)
    lfp = states[:, 1] - states[:, 0]
    lfp = jp.diff(lfp, axis=0)  # first difference flattens the 1/f trend
    # Spectra: 4 s windows, a new window every 1 s (i.e. 75 % overlap).
    win_size = int(4e3 / dt)
    step = win_size // 4
    ftfreq = jp.fft.fftfreq(win_size, dt) * 1e3  # dt in ms gives kHz -> Hz
    windows_psd = _welch_psd(lfp[:, 0], win_size, step)
    return ftfreq, windows_psd
# Load the empirical spectra; Pz[0] is the fitting target used below.
Pz = np.load('Sebastien_Spectrum_Pz.npy')
nfreq = Pz.shape[-1]  # frequency bins per empirical spectrum
# Run the simulation for three example parameter sets
# (A, B, a, b, v0, r, J, lsig); the second differs from the first only in lsig.
parameters = jp.array([
    (3.25, 22.0, 0.1, 0.05, 5.52, 0.56, 135.0, -8.0),
    (3.25, 22.0, 0.1, 0.05, 5.52, 0.56, 135.0, -7.0),
    (3.31, 22.5, 0.11, 0.049, 5.59, 0.55, 130.0, -8.0),
])
rng_keys = jax.random.split(jax.random.PRNGKey(1106), len(parameters))
psds = [run_sim_psd(p, k) for p, k in zip(parameters, rng_keys)]
# Show those PSDs alongside the empirical spectra (red).
pl.figure()
ftfreq = psds[0][0]
ftmask = (ftfreq > 0) * (ftfreq < 80)
# psd[ftmask] starts at the first positive bin, so plot it against the
# masked frequencies; ftfreq[:nfreq] starts at 0 Hz and would shift it a bin.
sim_freqs = ftfreq[ftmask][:nfreq]
for _, psd in psds:
    pl.plot(sim_freqs, jp.log(psd[ftmask])[:nfreq])
pl.plot(ftfreq[:nfreq], Pz.T, 'r')
pl.xlim([0, 50])
pl.grid(1)
pl.legend([str(_) for _ in parameters])
pl.ylabel('PSD')
pl.xlabel('Hz')
pl.title('Example Simulated Welch PSD on 60s Jansen-Rit')
# As a first fit, we use
# - the sum of squared errors between log-normalized spectra as the loss,
# - the first empirical spectrum Pz[0] as the "data" to fit (Pz is treated
#   as a log-PSD, hence the exp),
# - the first target_psd.size FFT bins, counted from 0 Hz.
target_psd = jp.exp(jp.array(Pz[0]))
def lognorm(vector):
    """Log of ``vector`` after scaling it to unit Euclidean (L2) norm.

    The scaling removes overall gain, so two spectra compared with this
    transform differ only in shape.
    """
    l2 = jp.linalg.norm(vector)
    unit = vector / l2
    return jp.log(unit)
def loss(opt_params, rng_key):
    """Sum of squared differences between simulated and target log-normalized PSDs.

    Only the first ``target_psd.size`` simulated FFT bins are compared, so
    the target is aligned with the simulation starting from the 0 Hz bin.
    """
    n_bins = target_psd.size
    sim_psd = run_sim_psd(opt_params, rng_key)[1]
    residual = lognorm(sim_psd[:n_bins]) - lognorm(target_psd)
    return jp.sum(jp.square(residual))
# Batched losses (no gradients): vmap maps over candidate columns (axis 1 of
# the parameter matrix) paired with one RNG key each; pmap splits batches
# over vb.cores devices.
vloss = jax.jit(jax.vmap(loss, in_axes=(1, 0)))
pvloss = jax.pmap(jax.vmap(loss, in_axes=(1, 0)))
# Random search: each round multiplicatively perturbs the best-so-far (bsf)
# parameters and keeps the candidate with the lowest loss.
rounds = 5000
round_size = 64
sd_period = 2000  # rounds between halvings of the perturbation size
if round_size % vb.cores:
    raise ValueError(
        f'round_size ({round_size}) must be divisible by vb.cores ({vb.cores})')
zs = vb.randn(rounds, parameters.shape[1], round_size)
rng_keys = jax.random.split(
    jax.random.PRNGKey(1106), rounds * round_size).reshape(rounds, round_size, 2)
bsf_params = jp.array(parameters[0]) * 1.3
bsf = vloss(bsf_params[:, None], rng_keys[0, :1]).min()
# A diverged starting simulation must not block every later improvement
# (any comparison against NaN is False).
bsf = jp.where(jp.isfinite(bsf), bsf, jp.inf)
sdseq = 0.5 / 2**np.r_[:20]
tik = time.time()
for i in (pbar := tqdm.trange(rounds)):
    # Clamp so long runs keep using the smallest perturbation size.
    sd = sdseq[min(i // sd_period, len(sdseq) - 1)]
    p = bsf_params[:, None] * (1 + sd * zs[i])
    # (nparams, round_size) -> (cores, nparams, round_size // cores) for pmap.
    pp = p.reshape(parameters.shape[1], vb.cores, -1).transpose(1, 0, 2)
    v = pvloss(pp, rng_keys[i].reshape(vb.cores, -1, 2)).ravel()
    # Diverged simulations may give NaN/inf losses; rank them last so they
    # never hide a finite candidate from argmin.
    v = jp.where(jp.isfinite(v), v, jp.inf)
    imin = jp.argmin(v)
    if v[imin] < bsf:
        bsf = v[imin]
        bsf_params = p[:, imin]
    # i + 1 rounds have completed at this point.
    sps = round_size * (i + 1) / (time.time() - tik)
    pbar.set_description(f'loss {bsf:0.3f}, {sps:0.1g} s/s, sd {sd}')
print('final params', bsf_params)
# Now show the PSD of the best parameters found against the target, over the
# same bins (from 0 Hz) that the loss compares.
n_bins = target_psd.size
pl.figure()
_, bsf_psd = run_sim_psd(bsf_params, rng_keys[0, 0])
pl.plot(ftfreq[:n_bins], lognorm(target_psd))
pl.plot(ftfreq[:n_bins], lognorm(bsf_psd[:n_bins]))
pl.xlim([0, 50])
pl.grid(1)
pl.legend(('Target', 'Fit'))
pl.ylabel('PSD')
pl.xlabel('Hz')
pl.title('Simulated PSDs & Fit')
pl.show()