Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Resting state integration #1174

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 236 additions & 2 deletions brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
from brian2.core.spikesource import SpikeSource
from brian2.core.variables import (Variables, LinkedVariable,
DynamicArrayVariable, Subexpression)
from brian2.core.namespace import get_local_namespace
from brian2.equations.equations import (Equations, DIFFERENTIAL_EQUATION,
SUBEXPRESSION, PARAMETER,
check_subexpressions,
extract_constant_subexpressions)
extract_constant_subexpressions,
SingleEquation, Expression)
from brian2.equations.refractory import add_refractoriness
from brian2.parsing.expressions import (parse_expression_dimensions,
is_boolean_expression)
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2.stateupdaters.base import StateUpdateMethod
from brian2.units.allunits import second
from brian2.units.fundamentalunits import (Quantity, Unit, DIMENSIONLESS,
Expand All @@ -31,10 +34,16 @@
fail_for_dimension_mismatch)
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers

from brian2.codegen.runtime.numpy_rt.numpy_rt import NumpyCodeObject
from .group import Group, CodeRunner, get_dtype
from .subgroup import Subgroup

try:
from scipy.optimize import root
scipy_available = True
except ImportError:
scipy_available = False

__all__ = ['NeuronGroup']

logger = get_logger(__name__)
Expand Down Expand Up @@ -920,3 +929,228 @@ def add_event_to_text(event):
add_event_to_text(event)

return '\n'.join(text)

def resting_state(self, x0 = {}):
'''
Calculate resting state of the system.

Parameters
----------
x0 : dict
Initial guess for the state variables. If any of the system's state variables are not
added, default value of 0 is mapped as the initial guess to the missing state variables.
Note: Time elapsed to locate the resting state would be lesser for better initial guesses.

Returns
-------
rest_state : dict
Dictioary with pair of state variables and resting state values. Returned values
are represented in SI units.
'''
# check scipy availability
if not scipy_available:
raise NotImplementedError("Scipy is not available for using `scipy.optimize.root()`")
# check state variables defined in initial guess are valid
if(x0.keys() - self.equations.diff_eq_names):
raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() -
self.equations.diff_eq_names))))

# Add 0 as the intial value for non-mentioned state variables in x0
x0.update({name : 0 for name in self.equations.diff_eq_names - x0.keys()})

# sort dictionary items
state_dict = dict(sorted(x0.items()))

# helper functions to create NeuronGroup object of corresponding equation
# For example: _rhs_equation() returns NeuronGroup object with equations representing
# Right-Hand-Side of self.equations and _jacobian_equation() returns NeuronGroup object
# with equations of jacobian matrix
rhs_states, rhs_group = _rhs_equation(self.equations, get_local_namespace(1))
jac_variables, jac_group = _jacobian_equation(self.equations, self.variables, get_local_namespace(1))

# solver function with _wrapper() as the callable function to be optimized
result = root(_wrapper, list(state_dict.values()), args = (rhs_states, rhs_group, jac_variables,
jac_group, state_dict.keys()), jac = True)

# check the result message for the status of convergence
if result.success == False:
raise Exception("Root calculation failed to converge. Poor initial guess may be the cause of the failure")

# evaluate the solution states to get state variables of jacobian
jac_state = _evaluate_states(jac_group, dict(zip(state_dict.keys(), result.x)), list(jac_variables.reshape(-1)))

# with the state values, prepare jacobian matrix
jac_matrix = np.zeros(jac_variables.shape)

for row in range(jac_variables.shape[0]):
for col in range(jac_variables.shape[1]):
jac_matrix[row, col] = float(jac_state[jac_variables[row, col]])

# check whether the solution is stable by using sign of eigenvalues
jac_eig = np.linalg.eigvals(jac_matrix)
if not np.all(np.real(jac_eig) < 0):
raise Exception('Equilibrium is not stable. Failed to converge to stable equilibrium')

# return the soultion in dictionary form
return dict(zip(state_dict.keys(), result.x))

def _rhs_equation(eqs, namespace = None, level = 0):

"""
Extract the RHS of a system of differential equations. External constants
can be provided via the namespace or will be taken from the local namespace.
Make a new set of equations, where differential equations are replaced by parameters,
and a new subexpression defines their RHS.

E.g. for 'dv/dt = -v / tau : volt' use:
'''v : volt
RHS_v = -v / tau : volt'''

This function could be used to find a resting state of the
system, i.e. a fixed point where the RHS of all equations are approximately 0.

Parameters
----------
eqs : `Equations`
The equations

Returns
-------
rhs_states : list
A list with the names of all variables defined as RHS of the equations
rhs_group : `NeuronGroup`
The NeuronGroup object
"""

if namespace is None:
namespace = get_local_namespace(level+1)

rhs_equations = []
for eq in eqs.values():
if eq.type == DIFFERENTIAL_EQUATION:
rhs_equations.append(SingleEquation(PARAMETER, eq.varname,
dimensions=eq.dim,
var_type=eq.var_type))
rhs_equations.append(SingleEquation(SUBEXPRESSION, 'RHS_'+eq.varname,
dimensions=eq.dim/second.dim,
var_type=eq.var_type,
expr=eq.expr))
else:
rhs_equations.append(eq)

# NeuronGroup with the obtained rhs_equations
rhs_group = NeuronGroup(1, model = Equations(rhs_equations),
codeobj_class = NumpyCodeObject,
namespace = namespace)
# states corresponding to RHS of the system of differential equations
rhs_states = ['RHS_' + name for name in eqs.diff_eq_names]

return (rhs_states, rhs_group)

def _jacobian_equation(eqs, group_variables, namespace = None, level = 0):

"""
Create jacobain expressions of a system of differential equations. External constants
can be provided via the namespace or will be taken from the local namespace.
Make a new set of equations, where differential equations are replaced by parameters,
and a new subexpression defines their jacobain expression.

This function could be used to find a resting state of the
system and check its stability

Parameters
----------
eqs : `Equations`
Equations of the parent NeuronGroup
group_variables : `Variables`
Variables of the parent NeuronGroup

Returns
-------
jac_matrix_variables : `2D-NumPy array`
2D- matrix of jacobian variables.
For example: jac_matrix_variables of model with two variables: u and v would be,
np.array([[J_u_u J_u_v],
[J_v_u J_v_v]])
jac_group : `NeuronGroup`
The NeuronGroup object
"""

if namespace is None:
namespace = get_local_namespace(level+1)
# prepare jac_eqs
diff_eqs = eqs.get_substituted_expressions(group_variables)
diff_eq_names = [name for name, _ in diff_eqs]
system = sympy.Matrix([str_to_sympy(diff_eq[1].code)
for diff_eq in diff_eqs])
J = system.jacobian([str_to_sympy(d) for d in diff_eq_names])
jac_eqs = []
for diff_eq_name, diff_eq in diff_eqs:
jac_eqs.append(SingleEquation(PARAMETER, diff_eq_name,
dimensions=eqs[diff_eq_name].dim,
var_type=eqs[diff_eq_name].var_type))
for var_idx, diff_eq_var in enumerate(diff_eq_names):
for diff_idx, diff_eq_diff in enumerate(diff_eq_names):
dimensions = eqs[diff_eq_var].dim/second.dim/eqs[diff_eq_diff].dim
expr = f'{sympy_to_str(J[var_idx, diff_idx])}'
if expr == '0':
expr = f'0*{dimensions!r}'
jac_eqs.append(SingleEquation(SUBEXPRESSION, f'J_{diff_eq_var}_{diff_eq_diff}',
dimensions=dimensions,
expr=Expression(expr)))
# NeuronGroup with the obtained jac_eqs
jac_group = NeuronGroup(1, model = Equations(jac_eqs),
codeobj_class = NumpyCodeObject,
namespace = namespace)
# prepare 2D matrix of jacobian variables
jac_matrix_variables = np.array(
[[f'J_{var}_{diff_var}' for diff_var in diff_eq_names]
for var in diff_eq_names])

return (jac_matrix_variables, jac_group)

def _evaluate_states(group, values, states):

"""
Evaluate the set of states when given values are set.
The function gets NeuronGroup object and set the given values to it;
and returns the values of states given

Parameters
----------
group : `NeuronGroup`
The NeuronGroup
values : dict-like
Values of states to be set to group
states: list
State variables for which values have to be get

Returns
-------
state_values : dict-like
Dictionary of state variables and their values
"""

group.set_states(values, units = False)
state_values = group.get_states(states)
return state_values

def _wrapper(args, rhs_states, rhs_group, jac_variables, jac_group, diff_eq_names):
"""
Vector function for which root needs to be calculated. Callable function of `scipy.optimize.root()`
"""
# match the argument values with correct variables
sorted_variable_dict = {name : arg for name, arg in zip(sorted(diff_eq_names), args)}
# get the values of `rhs_states` when given values(sorted_variable_dict) are set to rhs_group
rhs = _evaluate_states(rhs_group, sorted_variable_dict, rhs_states)
# get the values of `jac_varaibles` when given values(sorted_variable_dict) are set to jac_group
jac = _evaluate_states(jac_group, sorted_variable_dict, list(jac_variables.reshape(-1)))

# with the values prepare jacobian matrix
jac_matrix = np.zeros(jac_variables.shape)
for row in range(jac_variables.shape[0]):
for col in range(jac_variables.shape[1]):
jac_matrix[row, col] = float(jac[jac_variables[row, col]])

return [float(rhs['RHS_{}'.format(name)]) for name in sorted(diff_eq_names)], jac_matrix

92 changes: 91 additions & 1 deletion brian2/tests/test_neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,16 @@
from brian2.units.allunits import second, volt
from brian2.units.fundamentalunits import (DimensionMismatchError,
have_same_dimensions)
from brian2.units.stdunits import ms, mV, Hz
from brian2.units.stdunits import ms, mV, Hz, cm, msiemens, nA
from brian2.units.unitsafefunctions import linspace
from brian2.units.allunits import second, volt, umetre, siemens, ufarad
from brian2.utils.logger import catch_logs

try:
import scipy
scipy_available = True
except ImportError:
scipy_available = False

@pytest.mark.codegen_independent
def test_creation():
Expand Down Expand Up @@ -1716,6 +1722,85 @@ def test_semantics_mod():
assert_allclose(G.x[:], float_values % 3)
assert_allclose(G.y[:], float_values % 3)

def test_simple_resting_value():
"""
Test the resting state values of the system
"""
# simple model with single dependent variable, here it is not necessary
# to run the model as the resting value is certain
El = - 100
tau = 1 * ms
eqs = '''
dv/dt = (El - v)/tau : 1
'''
grp = NeuronGroup(1, eqs, method = 'exact')
resting_state = grp.resting_state()
assert_allclose(resting_state['v'], El)

# one more example
area = 100 * umetre ** 2
g_L = 1e-2 * siemens * cm ** -2 * area
E_L = 1000
Cm = 1 * ufarad * cm ** -2 * area
grp = NeuronGroup(10, '''dv/dt = I_leak / Cm : volt
I_leak = g_L*(E_L - v) : amp''')
resting_state = grp.resting_state({'v': float(10000)})
assert_allclose(resting_state['v'], E_L)

def test_failed_resting_state():
# check the failed to converge system is correctly notified to the user
area = 20000 * umetre ** 2
Cm = 1 * ufarad * cm ** -2 * area
gl = 5e-5 * siemens * cm ** -2 * area
El = -65 * mV
EK = -90 * mV
ENa = 50 * mV
g_na = 100 * msiemens * cm ** -2 * area
g_kd = 30 * msiemens * cm ** -2 * area
VT = -63 * mV
I = 0.01*nA
eqs = Equations('''
dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
(exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
(exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
(exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
''')
group = NeuronGroup(1, eqs, method='exponential_euler')
group.v = -70*mV
# very poor choice of initial values causing the convergence to fail
with pytest.raises(Exception):
group.resting_state({'v': 0, 'm': 100000000, 'n': 1000000, 'h': 100000000})

def test_unstable_resting_state():

# check the unstability of the converged solution
area = 20000 * umetre ** 2
Cm = 1 * ufarad * cm ** -2 * area
gl = 5e-5 * siemens * cm ** -2 * area
El = -65 * mV
EK = -90 * mV
ENa = 50 * mV
g_na = 100 * msiemens * cm ** -2 * area
g_kd = 30 * msiemens * cm ** -2 * area
VT = -63 * mV
I = 0.01*nA
eqs = Equations('''
dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
(exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
(exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
(exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
''')
group = NeuronGroup(1, eqs, method='exponential_euler')
group.v = -70*mV
# converging to unstable solution
with pytest.raises(Exception):
group.resting_state()

if __name__ == '__main__':
test_set_states()
Expand Down Expand Up @@ -1792,3 +1877,8 @@ def test_semantics_mod():
test_semantics_floor_division()
test_semantics_floating_point_division()
test_semantics_mod()
if scipy_available:
test_simple_resting_value()
test_failed_resting_state()
test_unstable_resting_state()