Skip to content

Commit

Permalink
fix inhomogeneous ODE solver
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Jul 19, 2023
1 parent 89ab59d commit 75a20eb
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 18 deletions.
2 changes: 1 addition & 1 deletion odetoolbox/analytic_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] =


#
# perform substtitution in update expressions ahead of time to save time later
# perform substitution in update expressions ahead of time to save time later
#

for k, v in self.update_expressions.items():
Expand Down
32 changes: 18 additions & 14 deletions odetoolbox/system_of_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,20 +223,24 @@ def generate_propagator_solver(self):
sym_str = "__P__{}__{}".format(str(self.x_[row]), str(self.x_[col]))
P_sym[row, col] = sympy.parsing.sympy_parser.parse_expr(sym_str, global_dict=Shape._sympy_globals)
P_expr[sym_str] = P[row, col]
if _is_zero(self.b_[col]):
# homogeneous ODE
update_expr_terms.append(sym_str + " * " + str(self.x_[col]))
else:
# inhomogeneous ODE
if _is_zero(self.A_[col, col]):
# of the form x' = const
update_expr_terms.append(sym_str + " * " + str(self.x_[col]) + " + " + Config().output_timestep_symbol + " * " + str(self.b_[col]))
else:
particular_solution = -self.b_[col] / self.A_[col, col]
update_expr_terms.append(sym_str + " * (" + str(self.x_[col]) + " - (" + str(particular_solution) + "))" + " + (" + str(particular_solution) + ")")

update_expr[str(self.x_[row])] = " + ".join(update_expr_terms)
update_expr[str(self.x_[row])] = sympy.parsing.sympy_parser.parse_expr(update_expr[str(self.x_[row])], global_dict=Shape._sympy_globals)
update_expr_terms.append(sym_str + " * " + str(self.x_[col]))

_update_expr_homogeneous = " + ".join(update_expr_terms)

if _is_zero(self.b_[row]):
# homogeneous ODE
_update_expr = _update_expr_homogeneous
else:
# inhomogeneous ODE
if _is_zero(self.A_[row, row]):
# of the form x' = const
_update_expr = _update_expr_homogeneous + " + " + Config().output_timestep_symbol + " * " + str(self.b_[row])
else:
sym_str = "__P__{}__{}".format(str(self.x_[row]), str(self.x_[row]))
particular_solution = -self.b_[row] / self.A_[row, row]
_update_expr = sym_str + " * (" + _update_expr_homogeneous + " - (" + str(particular_solution) + "))" + " + (" + str(particular_solution) + ")"

update_expr[str(self.x_[row])] = sympy.parsing.sympy_parser.parse_expr(_update_expr, global_dict=Shape._sympy_globals)
if not _is_zero(self.b_[row]):
# only simplify in case an inhomogeneous term is present
update_expr[str(self.x_[row])] = _custom_simplify_expr(update_expr[str(self.x_[row])])
Expand Down
132 changes: 129 additions & 3 deletions tests/test_inhomogeneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,22 @@
#

import numpy as np
import sympy
import pytest
import sympy
from scipy.integrate import odeint

import odetoolbox

from odetoolbox.analytic_integrator import AnalyticIntegrator
from odetoolbox.shapes import Shape
from odetoolbox.system_of_shapes import SystemOfShapes, PropagatorGenerationException
from odetoolbox.spike_generator import SpikeGenerator

try:
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
INTEGRATION_TEST_DEBUG_PLOTS = True
except ImportError:
INTEGRATION_TEST_DEBUG_PLOTS = False


class TestInhomogeneous:
Expand Down Expand Up @@ -212,3 +220,121 @@ def test_inhomogeneous_solver_second_order_combined_system_api(self):
result = odetoolbox.analysis(indict)
assert len(result) == 1 \
and result[0]["solver"] == "analytical"

def test_double_exponential(self):
r"""Test propagators generation for double exponential"""

def time_to_max(tau_1, tau_2):
r"""
Time of maximum.
"""
tmax = (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1)
return tmax

def unit_amplitude(tau_1, tau_2):
r"""
Scaling factor ensuring that amplitude of solution is one.
"""
tmax = time_to_max(tau_1, tau_2)
alpha = 1. / (np.exp(-tmax / tau_1) - np.exp(-tmax / tau_2))
return alpha

def flow(y, t, tau_1, tau_2, alpha, dt):
r"""
Rhs of ODE system to be solved.
"""
dy1dt = -y[0] / tau_1 + alpha * (1 / tau_2 - 1 / tau_1)
dy2dt = y[0] - y[1] / tau_2

return np.array([dy1dt, dy2dt])

indict = {"dynamics": [{"expression": "I_aux' = -I_aux / tau_1 + alpha * (1.0 / tau_2 - 1.0 / tau_1) * weighted_input_spikes",
"initial_values": {"I_aux": "0."}},
{"expression": "I' = I_aux - I / tau_2",
"initial_values": {"I": "0"}}],
"options": {"output_timestep_symbol": "__h"},
"parameters": {"tau_1": "10",
"tau_2": "2",
"w": "3.14",
"alpha": str(unit_amplitude(tau_1=10., tau_2=2.)),
"weighted_input_spikes": "0."}}

result = odetoolbox.analysis(indict)

assert len(result) == 1 \
and result[0]["solver"] == "analytical"

w = 3.14 # weight (amplitude; pA)
tau_1 = 10. # decay time constant (ms)
tau_2 = 2. # rise time constant (ms)
dt = .125 # time resolution (ms)
T = 500. # simulation time (ms)
input_spike_times = np.array([100., 300.]) # array of input spike times (ms)

alpha = unit_amplitude(tau_1, tau_2)

stimuli = [{"type": "list",
"list": " ".join([str(el) for el in input_spike_times]),
"variables": ["I_aux"]}]

spike_times = SpikeGenerator.spike_times_from_json(stimuli, T)

ODE_INITIAL_VALUES = {"I": 0., "I_aux": 0.}

# simulate with ode-toolbox
solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True)
assert len(solver_dict) == 1
solver_dict = solver_dict[0]
assert solver_dict["solver"] == "analytical"

N = int(np.ceil(T / dt) + 1)
timevec = np.linspace(0., T, N)
analytic_integrator = AnalyticIntegrator(solver_dict, spike_times)
analytic_integrator.shape_starting_values["I_aux"] = w * alpha * (1./tau_2 - 1./tau_1)
analytic_integrator.set_initial_values(ODE_INITIAL_VALUES)
analytic_integrator.reset()
state = {"timevec": [], "I": [], "I_aux": []}
for step, t in enumerate(timevec):
state_ = analytic_integrator.get_value(t)
state["timevec"].append(t)
for sym, val in state_.items():
state[sym].append(val)

# solve with odeint
ts0 = np.arange(0., input_spike_times[0] - dt / 2, dt)
ts1 = np.arange(input_spike_times[0], input_spike_times[1] - dt / 2, dt)
ts2 = np.arange(input_spike_times[1], T + dt, dt)

y_ = odeint(flow, [0., 0.], ts0, args=(tau_1, tau_2, alpha, dt))
y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts1, args=(tau_1, tau_2, alpha, dt))])
y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts2, args=(tau_1, tau_2, alpha, dt))])

rec_I_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I'])
rec_I_aux_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I_aux'])

if INTEGRATION_TEST_DEBUG_PLOTS:
tmax = time_to_max(tau_1, tau_2)
mpl.rcParams['text.usetex'] = True

fig, ax = plt.subplots(nrows=2, figsize=(5, 4), dpi=300)
ax[0].plot(timevec, state['I_aux'], '--', lw=3, color='k', label=r'$I_\mathsf{aux}(t)$ (NEST)')
ax[0].plot(timevec, state['I'], '-', lw=3, color='k', label=r'$I(t)$ (NEST)')
ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 0], '--', lw=2, color='r', label=r'$I_\mathsf{aux}(t)$ (odeint)')
ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 1], '-', lw=2, color='r', label=r'$I(t)$ (odeint)')

for tin in input_spike_times:
ax[0].vlines(tin + tmax, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors='k', linestyles=':')

ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 1] - rec_I_interp), label="I")
ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 0] - rec_I_aux_interp), linestyle="--", label="I_aux")
ax[1].set_ylabel("Error")

for _ax in ax:
_ax.set_xlim(0., T + dt)
_ax.legend()

ax[-1].set_xlabel(r'time (ms)')

fig.savefig('double_exp_test.png')

np.testing.assert_allclose(y_[:, 1], rec_I_interp, atol=1E-7)

0 comments on commit 75a20eb

Please sign in to comment.