From 75a20eb3e72598f9a42a7a27b674e1611c6335c0 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Wed, 19 Jul 2023 11:21:32 -0700 Subject: [PATCH] fix inhomogeneous ODE solver --- odetoolbox/analytic_integrator.py | 2 +- odetoolbox/system_of_shapes.py | 32 ++++---- tests/test_inhomogeneous.py | 132 +++++++++++++++++++++++++++++- 3 files changed, 148 insertions(+), 18 deletions(-) diff --git a/odetoolbox/analytic_integrator.py b/odetoolbox/analytic_integrator.py index 1ae10123..aa113957 100644 --- a/odetoolbox/analytic_integrator.py +++ b/odetoolbox/analytic_integrator.py @@ -98,7 +98,7 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] = # - # perform substtitution in update expressions ahead of time to save time later + # perform substitution in update expressions ahead of time to save time later # for k, v in self.update_expressions.items(): diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index 3d3f7327..801e3ef7 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -223,20 +223,24 @@ def generate_propagator_solver(self): sym_str = "__P__{}__{}".format(str(self.x_[row]), str(self.x_[col])) P_sym[row, col] = sympy.parsing.sympy_parser.parse_expr(sym_str, global_dict=Shape._sympy_globals) P_expr[sym_str] = P[row, col] - if _is_zero(self.b_[col]): - # homogeneous ODE - update_expr_terms.append(sym_str + " * " + str(self.x_[col])) - else: - # inhomogeneous ODE - if _is_zero(self.A_[col, col]): - # of the form x' = const - update_expr_terms.append(sym_str + " * " + str(self.x_[col]) + " + " + Config().output_timestep_symbol + " * " + str(self.b_[col])) - else: - particular_solution = -self.b_[col] / self.A_[col, col] - update_expr_terms.append(sym_str + " * (" + str(self.x_[col]) + " - (" + str(particular_solution) + "))" + " + (" + str(particular_solution) + ")") - - update_expr[str(self.x_[row])] = " + ".join(update_expr_terms) - update_expr[str(self.x_[row])] = sympy.parsing.sympy_parser.parse_expr(update_expr[str(self.x_[row])], global_dict=Shape._sympy_globals) + update_expr_terms.append(sym_str + " * " + str(self.x_[col])) + + _update_expr_homogeneous = " + ".join(update_expr_terms) + + if _is_zero(self.b_[row]): + # homogeneous ODE + _update_expr = _update_expr_homogeneous + else: + # inhomogeneous ODE + if _is_zero(self.A_[row, row]): + # of the form x' = const + _update_expr = _update_expr_homogeneous + " + " + Config().output_timestep_symbol + " * " + str(self.b_[row]) + else: + sym_str = "__P__{}__{}".format(str(self.x_[row]), str(self.x_[row])) + particular_solution = -self.b_[row] / self.A_[row, row] + _update_expr = sym_str + " * (" + _update_expr_homogeneous + " - (" + str(particular_solution) + "))" + " + (" + str(particular_solution) + ")" + + update_expr[str(self.x_[row])] = sympy.parsing.sympy_parser.parse_expr(_update_expr, global_dict=Shape._sympy_globals) if not _is_zero(self.b_[row]): # only simplify in case an inhomogeneous term is present update_expr[str(self.x_[row])] = _custom_simplify_expr(update_expr[str(self.x_[row])]) diff --git a/tests/test_inhomogeneous.py b/tests/test_inhomogeneous.py index 70ff4b85..5012f900 100644 --- a/tests/test_inhomogeneous.py +++ b/tests/test_inhomogeneous.py @@ -20,14 +20,22 @@ # import numpy as np -import sympy import pytest +import sympy +from scipy.integrate import odeint import odetoolbox from odetoolbox.analytic_integrator import AnalyticIntegrator -from odetoolbox.shapes import Shape -from odetoolbox.system_of_shapes import SystemOfShapes, PropagatorGenerationException +from odetoolbox.spike_generator import SpikeGenerator + +try: + import matplotlib as mpl + mpl.use('Agg') + import matplotlib.pyplot as plt + INTEGRATION_TEST_DEBUG_PLOTS = True +except ImportError: + INTEGRATION_TEST_DEBUG_PLOTS = False class TestInhomogeneous: @@ -212,3 +220,121 @@ def test_inhomogeneous_solver_second_order_combined_system_api(self): result = odetoolbox.analysis(indict) assert len(result) == 1 \ and result[0]["solver"] == "analytical" + + def test_double_exponential(self): + r"""Test propagators generation for double exponential""" + + def time_to_max(tau_1, tau_2): + r""" + Time of maximum. + """ + tmax = (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1) + return tmax + + def unit_amplitude(tau_1, tau_2): + r""" + Scaling factor ensuring that amplitude of solution is one. + """ + tmax = time_to_max(tau_1, tau_2) + alpha = 1. / (np.exp(-tmax / tau_1) - np.exp(-tmax / tau_2)) + return alpha + + def flow(y, t, tau_1, tau_2, alpha, dt): + r""" + Rhs of ODE system to be solved. + """ + dy1dt = -y[0] / tau_1 + alpha * (1 / tau_2 - 1 / tau_1) + dy2dt = y[0] - y[1] / tau_2 + + return np.array([dy1dt, dy2dt]) + + indict = {"dynamics": [{"expression": "I_aux' = -I_aux / tau_1 + alpha * (1.0 / tau_2 - 1.0 / tau_1) * weighted_input_spikes", + "initial_values": {"I_aux": "0."}}, + {"expression": "I' = I_aux - I / tau_2", + "initial_values": {"I": "0"}}], + "options": {"output_timestep_symbol": "__h"}, + "parameters": {"tau_1": "10", + "tau_2": "2", + "w": "3.14", + "alpha": str(unit_amplitude(tau_1=10., tau_2=2.)), + "weighted_input_spikes": "0."}} + + result = odetoolbox.analysis(indict) + + assert len(result) == 1 \ + and result[0]["solver"] == "analytical" + + w = 3.14 # weight (amplitude; pA) + tau_1 = 10. # decay time constant (ms) + tau_2 = 2. # rise time constant (ms) + dt = .125 # time resolution (ms) + T = 500. # simulation time (ms) + input_spike_times = np.array([100., 300.]) # array of input spike times (ms) + + alpha = unit_amplitude(tau_1, tau_2) + + stimuli = [{"type": "list", + "list": " ".join([str(el) for el in input_spike_times]), + "variables": ["I_aux"]}] + + spike_times = SpikeGenerator.spike_times_from_json(stimuli, T) + + ODE_INITIAL_VALUES = {"I": 0., "I_aux": 0.} + + # simulate with ode-toolbox + solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True) + assert len(solver_dict) == 1 + solver_dict = solver_dict[0] + assert solver_dict["solver"] == "analytical" + + N = int(np.ceil(T / dt) + 1) + timevec = np.linspace(0., T, N) + analytic_integrator = AnalyticIntegrator(solver_dict, spike_times) + analytic_integrator.shape_starting_values["I_aux"] = w * alpha * (1./tau_2 - 1./tau_1) + analytic_integrator.set_initial_values(ODE_INITIAL_VALUES) + analytic_integrator.reset() + state = {"timevec": [], "I": [], "I_aux": []} + for step, t in enumerate(timevec): + state_ = analytic_integrator.get_value(t) + state["timevec"].append(t) + for sym, val in state_.items(): + state[sym].append(val) + + # solve with odeint + ts0 = np.arange(0., input_spike_times[0] - dt / 2, dt) + ts1 = np.arange(input_spike_times[0], input_spike_times[1] - dt / 2, dt) + ts2 = np.arange(input_spike_times[1], T + dt, dt) + + y_ = odeint(flow, [0., 0.], ts0, args=(tau_1, tau_2, alpha, dt)) + y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts1, args=(tau_1, tau_2, alpha, dt))]) + y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts2, args=(tau_1, tau_2, alpha, dt))]) + + rec_I_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I']) + rec_I_aux_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I_aux']) + + if INTEGRATION_TEST_DEBUG_PLOTS: + tmax = time_to_max(tau_1, tau_2) + mpl.rcParams['text.usetex'] = True + + fig, ax = plt.subplots(nrows=2, figsize=(5, 4), dpi=300) + ax[0].plot(timevec, state['I_aux'], '--', lw=3, color='k', label=r'$I_\mathsf{aux}(t)$ (NEST)') + ax[0].plot(timevec, state['I'], '-', lw=3, color='k', label=r'$I(t)$ (NEST)') + ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 0], '--', lw=2, color='r', label=r'$I_\mathsf{aux}(t)$ (odeint)') + ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 1], '-', lw=2, color='r', label=r'$I(t)$ (odeint)') + + for tin in input_spike_times: + ax[0].vlines(tin + tmax, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors='k', linestyles=':') + + ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 1] - rec_I_interp), label="I") + ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 0] - rec_I_aux_interp), linestyle="--", label="I_aux") + ax[1].set_ylabel("Error") + + for _ax in ax: + _ax.set_xlim(0., T + dt) + _ax.legend() + + ax[-1].set_xlabel(r'time (ms)') + + fig.savefig('double_exp_test.png') + + np.testing.assert_allclose(y_[:, 1], rec_I_interp, atol=1E-7)