Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example to visualize the evolutuion of fitted parameters in bounded parameter space #43

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions examples/hh_nevergrad_errorevolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from brian2 import *
from brian2modelfitting import *
import pandas as pd
from matplotlib.animation import FuncAnimation


# load data
df_inp_traces = pd.read_csv('input_traces_hh.csv')
df_out_traces = pd.read_csv('output_traces_hh.csv')
inp_traces = df_inp_traces.to_numpy()
inp_traces = inp_traces[:-1, 1:]
out_traces = df_out_traces.to_numpy()
out_traces = out_traces[:-1, 1:]

# model parameters
area = 20000. * umetre ** 2
Cm = 1. * ufarad * cm ** -2 * area
E_l = -65. * mV
E_k = -90. * mV
E_na = 50. * mV
Vt = -63. * mV
init_v = {'v': -65. * mV}
dt = 0.01 * ms
defaultclock.dt = dt

# model definition
hodgkin_huxley = Equations(
'''dv/dt = (
(g_l * (E_l - v)
- g_na * (m ** 3) * h * (v - E_na)
- g_k * (n ** 4) * (v - E_k) + I) / Cm) : volt
dm/dt = (
0.32 * (mV ** -1 ) * (13.0 * mV - v + Vt)
/ (exp((13.0 * mV - v + Vt) / (4.0 * mV)) - 1.0) / ms * (1 - m)
- 0.28 * (mV ** -1) * (v - Vt - 40.0 * mV)
/ (exp((v - Vt - 40.0 * mV) / (5.0 * mV)) - 1.0) / ms * m) : 1
dn/dt = (
0.032 * (mV ** -1) * (15.0 * mV - v + Vt)
/ (exp((15.0 * mV - v + Vt) / (5.0 * mV)) - 1.0) / ms * (1.0 - n)
- 0.5 * exp((10.0 * mV - v + Vt) / (40.0 * mV)) / ms * n) : 1
dh/dt = (
0.128 * exp((17.0 * mV - v + Vt) / (18.0 * mV)) / ms * (1.0 - h)
- 4.0 / (1 + exp((40.0 * mV - v + Vt) / (5.0 * mV))) / ms * h) : 1
g_na : siemens (constant)
g_k : siemens (constant)
g_l : siemens (constant)''')

# optimizer instantiation
optimizer = NevergradOptimizer()

# metric instantiation
metric = MSEMetric()

# fitter definition and fitting procedure
n_samples = 40
fitter = TraceFitter(
model=hodgkin_huxley,
input_var='I', input=inp_traces * amp,
output_var='v', output=out_traces * mV,
dt=dt,
n_samples=n_samples,
method='exponential_euler',
param_init=init_v)


def callback(params, errors, best_params, best_error, index):
"""Custom callback.

Print the best error for each optimization round."""
print(f'[round {index + 1}]\tbest error: {np.min(errors)}')


# fitting procedure
n_rounds = 25
res, error = fitter.fit(
optimizer=optimizer,
metric=metric,
n_rounds=n_rounds,
callback=callback,
g_l=[1.e-09 * siemens, 1.e-07 * siemens],
g_na=[2.e-06 * siemens, 2.e-04 * siemens],
g_k=[6.e-07 * siemens, 6.e-05 * siemens])

# visualization of best fitted traces
fit_traces = fitter.generate_traces(params=res, param_init=init_v)

nrows = 2
ncols = fit_traces.shape[0]
fig, axs = plt.subplots(
nrows=nrows, ncols=ncols, sharex=True,
gridspec_kw={'height_ratios': [3, 1]}, figsize=(15, 4))
for idx in range(ncols):
axs[0, idx].plot(out_traces[idx, :].T, 'k-', label='$V_m^{measured}(t)$')
axs[0, idx].plot(fit_traces[idx, :].T / mV, 'r--', label='$V_m^{fit}(t)$')
axs[1, idx].plot(inp_traces[idx, :].T / amp, 'k-', label='$I(t)$')
axs[0, idx].grid()
axs[1, idx].grid()
axs[1, idx].set_xlabel('t [ms]')
if idx == 0:
axs[0, idx].set_ylabel('$V_m$ [mV]')
axs[1, idx].set_ylabel('$I$ [A/cm$^2$]')
handles, labels = [
(h + l) for h, l
in zip(axs[0, idx].get_legend_handles_labels(),
axs[1, idx].get_legend_handles_labels())]
fig.legend(handles, labels, loc='upper right')
plt.tight_layout()
plt.show()

# visualization of errors and parameters evolving over time
full_output = fitter.results(format='dict', use_units=False)
g_k = full_output['g_k']
g_na = full_output['g_na']
g_l = full_output['g_l']
error = full_output['error']

fig = plt.figure(figsize=(15, 5))
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
ax1.set(xlabel='$g_K$ [S]', ylabel='$g_{Na}$ [S]', zlabel='$g_l$ [S]',
xlim=(0, g_k.max() * 1.01), ylim=(0, g_na.max() * 1.01),
zlim=(0, g_l.max() * 1.01))
ax1.ticklabel_format(useOffset=True, style='scientific', scilimits=(0, 0))
ax2 = fig.add_subplot(1, 2, 2)
ax2.set(xlabel='round', ylabel='error')
ax2.grid()


def init():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the init and animate function could use a short comment for those not familiar with FuncAnimation.

"""Scatter plot of the initial population of parameters and the
best associated error."""
ax1.plot3D(g_k[:n_samples], g_na[:n_samples], g_l[:n_samples],
'b*', markersize=4, label='init population')
ax2.plot([0], np.min(error[:n_samples]),
'b*', markersize=8, label='init best error')
ax1.legend()
ax2.legend()


def animate(frame):
"""Scatter plot current population of parameters for each frame,
starting from the second round of optimization.

Number of frames should correspond to the number of optimization
rounds."""
istart = frame * n_samples
iend = istart + n_samples
if (res['g_k'] / siemens in g_k[istart:iend]
and res['g_na'] / siemens in g_na[istart:iend]
and res['g_l'] / siemens in g_l[istart:iend]):
ax1.plot3D([res['g_k']], [res['g_na']], [res['g_l']],
'r*', markersize=8, label='best params')
ax2.plot(frame, np.min(error[istart:iend]),
'r*', markersize=8, label='best error')
ax1.legend()
ax2.legend()
else:
ax1.plot3D(g_k[istart:iend], g_na[istart:iend], g_l[istart:iend],
'ko', markersize=4, zorder=-1, alpha=0.3)
ax2.plot(frame, np.min(error[istart:iend]), 'ko',
markersize=4, zorder=-1)


anim = FuncAnimation(
fig, animate, init_func=init, frames=np.arange(1, n_rounds), repeat=False)
plt.tight_layout()
plt.show()