diff --git a/.gitignore b/.gitignore
index fe292b6..6b47837 100644
--- a/.gitignore
+++ b/.gitignore
@@ -80,3 +80,6 @@ target/
 
 # Installation and auto-generated files
 brian2modelfitting.egg-info
+
+# sbi logs
+examples/sbi-logs/
diff --git a/brian2modelfitting/__init__.py b/brian2modelfitting/__init__.py
index c3c8bbc..e1994f6 100644
--- a/brian2modelfitting/__init__.py
+++ b/brian2modelfitting/__init__.py
@@ -6,6 +6,7 @@
 from .tests import run as run_test
 
 from .fitter import *
+from .inferencer import Inferencer
 from .optimizer import *
 from .metric import *
 from .simulator import *
@@ -16,4 +17,5 @@
            'Simulator', 'RuntimeSimulator', 'CPPStandaloneSimulator',
            'MSEMetric', 'Metric', 'GammaFactor', 'FeatureMetric',
            'SpikeMetric', 'TraceMetric', 'get_gamma_factor', 'firing_rate',
-           'Fitter', 'SpikeFitter', 'TraceFitter', 'OnlineTraceFitter']
+           'Fitter', 'SpikeFitter', 'TraceFitter', 'OnlineTraceFitter',
+           'Inferencer']
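The re-export above makes the new class importable straight from the package; a one-line usage sketch (commentary, not part of the diff):

```python
# the new public entry point added by this diff
from brian2modelfitting import Inferencer
```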
+ """ + namespace = {key: value + for key, value in get_local_namespace(level=level + 1).items() + if isinstance(value, (Number, Quantity, Function))} + namespace.update(additional_namespace) + return namespace + + +def get_param_dict(param_values, param_names, n_values): + """Return a dictionary compiled of parameter names and values. + + Parameters + ---------- + param_values : iterable + Iterable of size (``n_samples``, ``len(param_names)``) + containing parameter values. + param_names : iterable + Iterable containing parameter names + n_values : int + Total number of given values for a single parameter. + + Returns + ------- + dict + Dictionary containing key-value pairs thet correspond to a + parameter name and value(s) + """ + param_values = np.array(param_values) + param_dict = dict() + for name, value in zip(param_names, param_values.T): + param_dict[name] = (np.ones((n_values, )) * value) + return param_dict + + +def calc_prior(param_names, **params): + """Return prior distributparion over given parameters. Note that the + only available prior distribution currently supported is + multidimensional uniform distribution defined on a box. + + Parameters + ---------- + param_names : iterable + Iterable containing parameter names. + params : dict + Dictionary with keys that correspond to parameter names, and + values should be a single dimensional lists or arrays + + Return + ------ + sbi.utils.torchutils.BoxUniform + ``sbi`` compatible object that contains a uniform prior + distribution over a given set of parameter + """ + for param_name in param_names: + if param_name not in params: + raise TypeError(f'"Bounds must be set for parameter {param_name}') + prior_min = [] + prior_max = [] + for param_name in param_names: + prior_min.append(min(params[param_name]).item()) + prior_max.append(max(params[param_name]).item()) + prior = BoxUniform(low=torch.as_tensor(prior_min), + high=torch.as_tensor(prior_max)) + return prior + + +class Inferencer(object): + """Class for simulation-based inference. + + It offers an interface similar to that of `.Fitter` class but + instead of fitting, neural density estimator is trained using a + generative model. This class serves as a wrapper for ``sbi`` + library for inferencing posterior over unknown parameters of a + given model. + + Parameters + ---------- + dt : brian2.units.fundamentalunits.Quantity + Integration time step. + model : str or brian2.equations.equations.Equations + Single cell model equations. + input : dict + Input traces in dictionary format, where key corresponds to the + name of the input variable as defined in ``model`` and value + corresponds to a single dimensional array of data traces. + output : dict + Dictionary of recorded (or simulated) output data traces, where + key corresponds to the name of the output variable as defined + in ``model`` and value corresponds to a single dimensional + array of recorded data traces. + method : str, optional + Integration method. + threshold : str, optional + The condition which produces spikes. Should be a single line + boolean expression. + reset : str, optional + The (possibly multi-line) string with the code to execute on + reset. 
+class Inferencer(object):
+    """Class for simulation-based inference.
+
+    It offers an interface similar to that of the `.Fitter` class, but
+    instead of fitting, a neural density estimator is trained using a
+    generative model. This class serves as a wrapper for the ``sbi``
+    library for inferring the posterior over the unknown parameters of
+    a given model.
+
+    Parameters
+    ----------
+    dt : brian2.units.fundamentalunits.Quantity
+        Integration time step.
+    model : str or brian2.equations.equations.Equations
+        Single cell model equations.
+    input : dict
+        Input traces in dictionary format, where the key corresponds
+        to the name of the input variable as defined in ``model``, and
+        the value corresponds to a 2-D array of input data traces.
+    output : dict
+        Dictionary of recorded (or simulated) output data traces,
+        where the key corresponds to the name of the output variable
+        as defined in ``model``, and the value corresponds to a 2-D
+        array of recorded data traces.
+    method : str, optional
+        Integration method.
+    threshold : str, optional
+        The condition which produces spikes. Should be a single-line
+        boolean expression.
+    reset : str, optional
+        The (possibly multi-line) string with the code to execute on
+        reset.
+    refractory : str, optional
+        Either the length of the refractory period (e.g., ``2*ms``), a
+        string expression that evaluates to the length of the
+        refractory period after each spike, e.g., ``'(1 + rand())*ms'``,
+        or a string expression evaluating to a boolean value, given the
+        condition under which the neuron stays refractory after a
+        spike, e.g., ``'v > -20*mV'``.
+    param_init : dict, optional
+        Dictionary of state variables to be initialized with
+        respective values.
+    """
+    def __init__(self, dt, model, input, output, method=None, threshold=None,
+                 reset=None, refractory=None, param_init=None):
+        # time scale
+        self.dt = dt
+
+        # model equations
+        if isinstance(model, str):
+            model = Equations(model)
+        elif not isinstance(model, Equations):
+            raise TypeError('``model`` must be either a string or an'
+                            ' ``Equations`` object.')
+
+        # input data traces
+        if not isinstance(input, Mapping):
+            raise TypeError('``input`` argument must be a dictionary mapping'
+                            ' the name of the input variable to the input'
+                            ' data traces.')
+        if len(input) > 1:
+            raise NotImplementedError('Only a single input is supported.')
+        input_var = list(input.keys())[0]
+        input = input[input_var]
+        if input_var not in model.identifiers:
+            raise NameError(f'{input_var} is not an identifier in the model.')
+
+        # output data traces
+        if not isinstance(output, Mapping):
+            raise TypeError('``output`` argument must be a dictionary mapping'
+                            ' the name of the output variable to the output'
+                            ' data traces.')
+        output_var = list(output.keys())
+        output = list(output.values())
+        for o_var in output_var:
+            if o_var not in model.names:
+                raise NameError(f'{o_var} is not a model variable.')
+        self.output_var = output_var
+        self.output = output
+
+        # create variable for parameter names
+        self.param_names = model.parameter_names
+
+        # set the simulation time for a given time scale
+        self.n_traces, n_steps = input.shape
+        self.sim_time = dt * n_steps
+
+        # handle multiple output variables
+        self.output_dim = []
+        for o_var, out in zip(self.output_var, self.output):
+            self.output_dim.append(model[o_var].dim)
+            fail_for_dimension_mismatch(out, self.output_dim[-1],
+                                        'The provided target values must have'
+                                        ' the same units as the variable'
+                                        f' {o_var}')
+
+        # add input to equations
+        self.model = model
+        input_dim = get_dimensions(input)
+        input_dim = '1' if input_dim is DIMENSIONLESS else repr(input_dim)
+        input_eqs = f'{input_var} = input_var(t, i % n_traces) : {input_dim}'
+        self.model += input_eqs
+
+        # add output to equations
+        counter = 0
+        for o_var, o_dim in zip(self.output_var, self.output_dim):
+            counter += 1
+            output_expr = f'output_var_{counter}(t, i % n_traces)'
+            output_dim = ('1' if o_dim is DIMENSIONLESS else repr(o_dim))
+            output_eqs = f'{o_var}_target = {output_expr} : {output_dim}'
+            self.model += output_eqs
+
+        # create ``TimedArray`` object for input w.r.t. a given time scale
+        self.input_traces = TimedArray(input.transpose(), dt=self.dt)
+
+        # handle initial values for the ODE system
+        if not param_init:
+            param_init = {}
+        for param in param_init.keys():
+            if not (param in self.model.diff_eq_names or
+                    param in self.model.parameter_names):
+                raise ValueError(f'{param} is not a model variable or a'
+                                 ' parameter in the model')
+        self.param_init = param_init
+
+        # handle the rest of optional parameters for the ``NeuronGroup`` class
+        self.method = method
+        self.threshold = threshold
+        self.reset = reset
+        self.refractory = refractory
+
+        # placeholder for samples
+        self.n_samples = None
+        self.samples = None
+        # placeholder for the posterior
+        self.posterior = None
+
+    @property
+    def n_neurons(self):
+        """Return the number of neurons that are used in the
+        `.NeuronGroup` class while generating data for training the
+        neural density estimator.
+
+        Unlike the `.Fitter` class, `.Inferencer` does not take the
+        total number of samples in the constructor. Thus, this
+        property becomes available only after the simulation is
+        performed.
+
+        Returns
+        -------
+        int
+            Total number of neurons.
+        """
+        if self.n_samples is None:
+            raise ValueError('Number of samples is not yet defined.'
+                             ' Call the ``generate_training_data`` method'
+                             ' first.')
+        return self.n_traces * self.n_samples
+
+    def setup_simulator(self, network_name, n_neurons, output_var, param_init,
+                        level=1):
+        """Return the configured simulator.
+
+        Parameters
+        ----------
+        network_name : str
+            Network name.
+        n_neurons : int
+            Number of neurons, which equals the number of samples
+            times the number of input/output traces.
+        output_var : str
+            Name of the output variable.
+        param_init : dict
+            Dictionary of state variables to be initialized with
+            respective values.
+        level : int, optional
+            How far to go back to get the locals/globals.
+
+        Returns
+        -------
+        brian2modelfitting.simulator.Simulator
+            Simulator configured w.r.t. the available device.
+        """
+        # configure the simulator
+        simulator = configure_simulator()
+
+        # update the local namespace
+        namespace = get_full_namespace({'input_var': self.input_traces,
+                                        'n_traces': self.n_traces},
+                                       level=level+1)
+        counter = 0
+        for out in self.output:
+            counter += 1
+            namespace[f'output_var_{counter}'] = TimedArray(out.transpose(),
+                                                            dt=self.dt)
+
+        # setup neuron group
+        kwds = {}
+        if self.method is not None:
+            kwds['method'] = self.method
+        model = (self.model
+                 + Equations('iteration : integer (constant, shared)'))
+        neurons = NeuronGroup(N=n_neurons,
+                              model=model,
+                              threshold=self.threshold,
+                              reset=self.reset,
+                              refractory=self.refractory,
+                              dt=self.dt,
+                              namespace=namespace,
+                              name='neurons',
+                              **kwds)
+        network = Network(neurons)
+        network.add(StateMonitor(source=neurons, variables=output_var,
+                                 record=True, dt=self.dt,
+                                 name='statemonitor'))
+
+        # initialize the simulator
+        simulator.initialize(network, param_init, name=network_name)
+        return simulator
+
+    def initialize_prior(self, **params):
+        """Return the prior uniform distribution over the parameters.
+
+        Parameters
+        ----------
+        params : dict
+            Bounds for each parameter.
+
+        Returns
+        -------
+        sbi.utils.BoxUniform
+            Uniformly distributed prior over the given parameters.
+        """
+        for param in params:
+            if param not in self.param_names:
+                raise ValueError(f'Parameter {param} must be defined as a'
+                                 ' model parameter')
+        prior = calc_prior(self.param_names, **params)
+        return prior
+ """ + for param in params: + if param not in self.param_names: + raise ValueError(f'Parameter {param} must be defined as a' + ' model\'s parameter') + prior = calc_prior(self.param_names, **params) + return prior + + def generate_training_data(self, n_samples, prior): + """Return sampled prior and executed simulator containing + recorded variables to be used for training the neural density + estimator. + + Parameter + --------- + n_samples : int + The number of samples. + prior : sbi.utils.BoxUniform + Uniformly distributed prior over given parameters. + + Returns + ------- + numpy.ndarray + Sampled prior of shape (``n_samples``, -1). + """ + # set n_samples to class variable to be able to call self.n_neurons + self.n_samples = n_samples + + # sample from prior + theta = prior.sample((n_samples, )) + theta = np.atleast_2d(theta.numpy()) + return theta + + def extract_summary_statistics(self, theta, features, level=1): + """Return summary statistics to be used for training the neural + density estimator. + + Parameters + ---------- + theta : numpy.ndarray + Sampled prior of shape (``n_samples``, -1). + features : list + List of callables that take the voltage trace and output + summary statistics stored in `numpy.array`. + level : int, optional + How far to go back to get the locals/globals. + + Returns + ------- + numpy.ndarray + Summary statistics. + """ + # repeat each row for how many input/output different trace are there + _theta = np.repeat(theta, repeats=self.n_traces, axis=0) + + # create a dictionary with repeated sampled prior + d_param = get_param_dict(_theta, self.param_names, self.n_neurons) + + # set up and run the simulator + network_name = 'infere' + simulator = self.setup_simulator(network_name=network_name, + n_neurons=self.n_neurons, + output_var=self.output_var, + param_init=self.param_init, + level=level+1) + simulator.run(self.sim_time, d_param, self.param_names, iteration=0, + name=network_name) + + # extract features + obs = simulator.statemonitor.recorded_variables + x = [] + for ov in self.output_var: + x_val = obs[ov].get_value() + summary_statistics = [] + for feature in features: + summary_statistics.append(feature(x_val)) + x.append(summary_statistics) + x = np.array(x, dtype=np.float32) + x = x.reshape((self.n_samples, -1)) + return x + + def init_inference(self, inference_method, density_estimator_model, prior, + **inference_kwargs): + """Return instantiated inference object. + + Parameters + ---------- + inference_method : str + Inference method. Either of SNPE, SNLE or SNRE. + density_estimator_model : str + The type of density estimator to be created. Either + ``mdn``, ``made``, ``maf``, ``nsf`` for SNPE and SNLE, or + ``linear``, ``mlp``, ``resnet`` for SNRE. + prior : sbi.utils.BoxUniform + Uniformly distributed prior over given parameters. + inference_kwargs : dict, optional + Additional keyword arguments for + ``sbi.utils.get_nn_models.posterior_nn`` method. + + Returns + ------- + sbi.inference.NeuralInference + Instantiated inference object. + """ + try: + inference_method = str.upper(inference_method) + inference_method_fun = getattr(sbi.inference, inference_method) + except AttributeError: + raise NameError(f'Inference method {inference_method} is not ' + 'supported. 
+        else:
+            if inference_method == 'SNPE':
+                density_estimator_builder = posterior_nn(
+                    model=density_estimator_model, **inference_kwargs)
+            elif inference_method == 'SNLE':
+                density_estimator_builder = likelihood_nn(
+                    model=density_estimator_model, **inference_kwargs)
+            else:
+                density_estimator_builder = classifier_nn(
+                    model=density_estimator_model, **inference_kwargs)
+        inference = inference_method_fun(prior, density_estimator_builder,
+                                         device='cpu',
+                                         show_progress_bars=True)
+        return inference
+
+    def train(self, inference, theta, x, *args, **train_kwargs):
+        """Return the inference object with stored training data and
+        the trained density estimator.
+
+        Parameters
+        ----------
+        inference : sbi.inference.NeuralInference
+            Instantiated inference object with stored parameters and
+            simulation outputs prepared for training.
+        theta : torch.tensor
+            Sampled prior.
+        x : torch.tensor
+            Summary statistics.
+        args : list, optional
+            Contains a uniformly distributed sbi.utils.BoxUniform
+            prior/proposal. Used only for SNPE; for SNLE and SNRE,
+            ``proposal`` should not be passed to the
+            ``append_simulations`` method, thus ``args`` should not
+            be passed.
+        train_kwargs : dict, optional
+            Additional keyword arguments for the ``train`` method of
+            the ``sbi.inference.NeuralInference`` object.
+
+        Returns
+        -------
+        tuple
+            ``sbi.inference.NeuralInference`` object with stored
+            parameters and simulation outputs prepared for training,
+            and the trained neural density estimator object.
+        """
+        inference = inference.append_simulations(theta, x, *args)
+        density_estimator = inference.train(**train_kwargs)
+        return (inference, density_estimator)
+
+    def build_posterior(self, inference, density_estimator,
+                        **posterior_kwargs):
+        """Return the inference object and the built posterior.
+
+        Parameters
+        ----------
+        inference : sbi.inference.NeuralInference
+            Instantiated inference object with stored parameters and
+            simulation outputs prepared for training.
+        density_estimator : torch.nn.Module
+            Trained neural density estimator.
+        posterior_kwargs : dict, optional
+            Additional keyword arguments for the ``build_posterior``
+            method of the ``sbi.inference.NeuralInference`` object.
+
+        Returns
+        -------
+        tuple
+            ``sbi.inference.NeuralInference`` object with stored
+            parameters and simulation outputs prepared for training,
+            and the posterior distribution over the parameters.
+        """
+        posterior = inference.build_posterior(density_estimator,
+                                              **posterior_kwargs)
+        return (inference, posterior)
+
+    def infere(self, n_samples, features, n_rounds=1, inference_method='SNPE',
+               density_estimator_model='maf', inference_kwargs={},
+               train_kwargs={}, posterior_kwargs={}, **params):
+        """Return the trained posterior.
+
+        Parameters
+        ----------
+        n_samples : int
+            The number of samples.
+        features : list
+            List of callables that take the voltage trace and output
+            summary statistics stored in `numpy.array`.
+        n_rounds : int, optional
+            If ``n_rounds`` is set to 1, amortized inference will be
+            performed. Otherwise, if ``n_rounds`` is an integer larger
+            than 1, multi-round inference will be performed.
+        inference_method : str, optional
+            Inference method. Either of SNPE, SNLE or SNRE.
+        density_estimator_model : str, optional
+            The type of density estimator to be created. Either
+            ``mdn``, ``made``, ``maf`` or ``nsf``.
+        inference_kwargs : dict, optional
+            Additional keyword arguments for the
+            ``sbi.utils.get_nn_models.posterior_nn`` method.
+        train_kwargs : dict, optional
+            Dictionary of arguments for training the posterior
+            estimator.
+        posterior_kwargs : dict, optional
+            Additional keyword arguments for the ``build_posterior``
+            method of the ``sbi.inference.NeuralInference`` object.
+        params : dict
+            Bounds for each parameter.
+
+        Returns
+        -------
+        sbi.inference.posteriors.DirectPosterior
+            Trained posterior.
+        """
+        if not isinstance(n_rounds, int) or n_rounds < 1:
+            raise ValueError('Number of rounds must be a positive integer.')
+        if not isinstance(inference_method, str):
+            raise TypeError('Inference method must be specified as a string.')
+        inference_method = inference_method.upper()
+        if inference_method not in ['SNPE', 'SNLE', 'SNRE']:
+            raise ValueError(f'Inference method {inference_method} is not '
+                             'supported. Choose between SNPE, SNLE or SNRE.')
+
+        # observation the focus is on
+        x_o = []
+        for o in self.output:
+            o = np.array(o)
+            obs = []
+            for feature in features:
+                obs.extend(feature(o.transpose()))
+            x_o.append(obs)
+        x_o = torch.tensor(x_o, dtype=torch.float32)
+        self.x_o = x_o
+
+        # initialize the prior
+        prior = self.initialize_prior(**params)
+
+        # initialize the inference object
+        inference = self.init_inference(inference_method,
+                                        density_estimator_model,
+                                        prior,
+                                        **inference_kwargs)
+
+        # allocate an empty list of posteriors
+        posteriors = []
+        proposal = prior
+        for round in range(n_rounds):
+            print(f'Round {round + 1} of inference.')
+
+            # only SNPE expects the current proposal to accompany the
+            # training data in ``append_simulations``
+            if inference_method == 'SNPE':
+                args = [proposal]
+            else:
+                args = []
+
+            # extract the training data and make adjustments for ``sbi``
+            print('Generating training data...')
+            theta = self.generate_training_data(n_samples, proposal)
+            theta = torch.tensor(theta, dtype=torch.float32)
+
+            # extract the summary statistics and make adjustments for ``sbi``
+            x = self.extract_summary_statistics(theta, features)
+            x = torch.tensor(x)
+
+            # pass the simulated data to the inference object and train it
+            print('Training the neural density estimator...')
+            inference, density_estimator = self.train(inference,
+                                                      theta, x,
+                                                      *args, **train_kwargs)
+
+            # use the density estimator to build the posterior
+            inference, posterior = self.build_posterior(inference,
+                                                        density_estimator,
+                                                        **posterior_kwargs)
+
+            # append the current posterior to the list of posteriors
+            posteriors.append(posterior)
+
+            # update the proposal given the observation
+            proposal = posterior.set_default_x(x_o)
+        self.posterior = posterior
+        return posterior
+
+    def sample(self, shape, posterior=None, **kwargs):
+        """Return samples from the posterior distribution.
+
+        Parameters
+        ----------
+        shape : tuple
+            Desired shape of samples that are drawn from the
+            posterior.
+        posterior : sbi.inference.posteriors.DirectPosterior, optional
+            Posterior distribution.
+        **kwargs : dict, optional
+            Additional keyword arguments for the ``sample`` method of
+            the ``sbi.inference.posteriors.DirectPosterior`` class.
+
+        Returns
+        -------
+        torch.tensor
+            Samples from the posterior of the shape as given in
+            ``shape``.
+        """
+        if posterior is not None:
+            p = posterior
+        else:
+            p = self.posterior
+        if p is None:
+            raise ValueError("Need to provide posterior argument if no"
+                             " posterior has been calculated by the"
+                             " 'infere' method.")
+        samples = p.sample(shape)
+        self.samples = samples
+        return samples
+
+    def pairplot(self, samples=None, **kwargs):
+        """Plot samples in a 2-D grid with marginals and pairwise
+        marginals.
+
+        Check ``sbi.analysis.plot.pairplot`` for more details.
+
+        Parameters
+        ----------
+        samples : iterable, optional
+            Samples used to build the pairplot.
+        **kwargs : dict, optional
+            Additional keyword arguments for the
+            ``sbi.analysis.plot.pairplot`` function.
+
+        Returns
+        -------
+        tuple
+            Figure and axis of the posterior distribution plot.
+        """
+        if samples is not None:
+            s = samples
+        else:
+            s = self.samples
+        if s is None:
+            raise ValueError('Provide samples or call the ``sample`` method'
+                             ' first.')
+        fig, axes = sbi.analysis.pairplot(s, **kwargs)
+        return fig, axes
+
+    def generate_traces(self, posterior=None, output_var=None,
+                        param_init=None, level=0):
+        """Generate traces for a single sample drawn from the trained
+        posterior and all inputs.
+
+        Parameters
+        ----------
+        posterior : sbi.inference.posteriors.DirectPosterior, optional
+            Posterior distribution.
+        output_var : str or sequence of str, optional
+            Name of the output variable to be monitored; it can also
+            be a sequence of names to record multiple variables.
+        param_init : dict, optional
+            Dictionary of initial values for the model.
+        level : int, optional
+            How far to go back to get the locals/globals.
+
+        Returns
+        -------
+        brian2.units.fundamentalunits.Quantity or dict
+            If a single output variable is observed, a 2-D array of
+            traces generated by using a set of parameters sampled from
+            the trained posterior distribution, of shape
+            (``n_traces``, number of time steps). Otherwise, a
+            dictionary with keys set to names of output variables, and
+            values set to the generated traces of the respective
+            output variables.
+        """
+        # sample a single set of parameters from the posterior distribution
+        if posterior is not None:
+            p = posterior
+        else:
+            p = self.posterior
+        if p is None:
+            raise ValueError("Need to provide posterior argument if no"
+                             " posterior has been calculated by the"
+                             " 'infere' method.")
+        params = p.sample((1, ))
+
+        # set the output variable that is monitored
+        if output_var is None:
+            output_var = self.output_var
+
+        # set up the initial values
+        if param_init is None:
+            param_init = self.param_init
+        else:
+            combined_param_init = dict(self.param_init)
+            combined_param_init.update(param_init)
+            param_init = combined_param_init
+
+        # create a dictionary with the sampled parameters
+        d_param = get_param_dict(params, self.param_names, self.n_traces)
+
+        # set up and run the simulator
+        network_name = 'generate_traces'
+        simulator = self.setup_simulator(network_name=network_name,
+                                         n_neurons=self.n_traces,
+                                         output_var=output_var,
+                                         param_init=param_init,
+                                         level=level+1)
+        simulator.run(self.sim_time, d_param, self.param_names, iteration=0,
+                      name=network_name)
+
+        # create a dictionary of traces for multiple observed output variables
+        if len(output_var) > 1:
+            traces = {}
+            for ov in output_var:
+                traces[ov] = getattr(simulator.statemonitor, ov)[:]
+        else:
+            traces = getattr(simulator.statemonitor, output_var[0])[:]
+        return traces
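For orientation (commentary, not part of the diff): the public `infere` method composes the lower-level building blocks above in a fixed order. A rough sketch of a single amortized round, assuming an already-constructed `inferencer` object and a `features` list; the bounds below are purely illustrative and must cover every free model parameter:

```python
import torch
from brian2 import siemens

# one entry per free parameter of the model (illustrative values)
bounds = {'gl': [1e-9 * siemens, 1e-7 * siemens]}

prior = inferencer.initialize_prior(**bounds)
theta = inferencer.generate_training_data(1000, prior)  # numpy, (1000, n_params)
theta = torch.tensor(theta, dtype=torch.float32)
x = torch.tensor(inferencer.extract_summary_statistics(theta, features))
inference = inferencer.init_inference('SNPE', 'maf', prior)
# for SNPE the proposal (here: the prior) accompanies the training data
inference, estimator = inferencer.train(inference, theta, x, prior)
inference, posterior = inferencer.build_posterior(inference, estimator)
```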
diff --git a/brian2modelfitting/utils.py b/brian2modelfitting/utils.py
index 984e0a6..534c696 100644
--- a/brian2modelfitting/utils.py
+++ b/brian2modelfitting/utils.py
@@ -2,7 +2,7 @@
 from brian2 import have_same_dimensions
 from brian2.units.fundamentalunits import Quantity
 
-from tqdm.autonotebook import tqdm
+from tqdm.auto import tqdm
 
 
 def _format_quantity(v, precision=3):
@@ -12,7 +12,8 @@ def _format_quantity(v, precision=3):
     return f'{v:.{precision}g}'
 
 
-def callback_text(params, errors, best_params, best_error, index, additional_info):
+def callback_text(params, errors, best_params, best_error, index,
+                  additional_info):
     """Default callback print-out for Fitters"""
     params = []
     for p, v in sorted(best_params.items()):
@@ -25,14 +26,11 @@ def callback_text(params, errors, best_params, best_error, index, additional_inf
         for error, normed_error, varname in zip(additional_info['objective_errors'],
                                                 additional_info['objective_errors_normalized'],
                                                 additional_info['output_var']):
-
             if not have_same_dimensions(error, normed_error) or error != normed_error:
                 raw_error_str = f', unnormalized error: {_format_quantity(error)}'
             else:
                 raw_error_str = ''
-
             errors.append(f'{_format_quantity(normed_error)} ({varname}{raw_error_str})')
-
         error_sum = ' + '.join(errors)
         print(f"{round}Best parameters {param_str}\n"
              f"{' '*len(round)}Best error: {best_error_str} = {error_sum}")
@@ -53,8 +51,9 @@ def callback_text(params, errors, best_params, best_error, index, additional_inf
         print(f"{' ' * len(round)}Best error: {best_error_str} ({additional_info['output_var'][0]})")
 
 
-def callback_none(params, errors, best_params, best_error, index, additional_info):
+def callback_none(params, errors, best_params, best_error, index,
+                  additional_info):
     """Non-verbose callback"""
     pass
 
@@ -83,14 +81,12 @@ def callback_setup(set_type, n_rounds):
     elif type(set_type) is FunctionType:
         callback = set_type
     else:
-        raise TypeError("callback has to be a str ('text' or 'progressbar'), "
-                        "callable or None")
-
+        raise TypeError('callback has to be a str (`text` or `progressbar`), '
+                        'callable or None')
     return callback
 
 
 def make_dic(names, values):
     """Create dictionary based on list of strings and 2D array"""
     result_dict = {name: value for name, value in zip(names, values)}
-
     return result_dict
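One contract worth spelling out before the example below (commentary, not part of the diff): each callable in `features` receives the recorded output as a 2-D array with time along axis 0 and one column per simulated neuron (note how `infere` transposes the observed traces before applying the features), so summary statistics should reduce over axis 0, exactly as the example's lambdas do:

```python
# each feature maps a (n_timesteps, n_neurons) array to one value per neuron
features = [
    lambda x: x.mean(axis=0),  # mean of each trace
    lambda x: x.std(axis=0),   # standard deviation of each trace
    lambda x: x.ptp(axis=0),   # peak-to-peak amplitude of each trace
]
```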
diff --git a/examples/hh_sbi.py b/examples/hh_sbi.py
new file mode 100644
index 0000000..4ca1c00
--- /dev/null
+++ b/examples/hh_sbi.py
@@ -0,0 +1,87 @@
+from brian2 import *
+from brian2modelfitting import *
+import pandas as pd
+
+
+# Load input and output data traces
+df_inp_traces = pd.read_csv('input_traces_hh.csv')
+df_out_traces = pd.read_csv('output_traces_hh.csv')
+inp_traces = df_inp_traces.to_numpy()
+inp_traces = inp_traces[:3, 1:]
+out_traces = df_out_traces.to_numpy()
+out_traces = out_traces[:3, 1:]
+
+# Model and its parameters
+area = 20_000 * um ** 2
+El = -65 * mV
+EK = -90 * mV
+ENa = 50 * mV
+VT = -63 * mV
+dt = 0.01 * ms
+eqs = '''
+    dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
+    dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
+        (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
+        (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
+    dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
+        (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
+    dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
+    g_na : siemens (constant)
+    g_kd : siemens (constant)
+    gl : siemens (constant)
+    Cm : farad (constant)
+'''
+
+# Simulation-based inference object instantiation
+inferencer = Inferencer(dt=dt, model=eqs,
+                        input={'I': inp_traces * amp},
+                        output={'v': out_traces * mV},
+                        method='exponential_euler',
+                        threshold='m > 0.5',
+                        refractory='m > 0.5',
+                        param_init={'v': 'VT'})
+
+# Generate the prior and train the neural density estimator
+inferencer.infere(n_samples=1000,
+                  features=[
+                      lambda x: x.mean(axis=0),
+                      lambda x: x.std(axis=0),
+                      lambda x: x.ptp(axis=0)],
+                  n_rounds=2,
+                  density_estimator_model='made',
+                  gl=[1e-09 * siemens, 1e-07 * siemens],
+                  g_na=[2e-06 * siemens, 2e-04 * siemens],
+                  g_kd=[6e-07 * siemens, 6e-05 * siemens],
+                  Cm=[0.1 * uF * cm ** -2 * area, 2 * uF * cm ** -2 * area])
+
+# Draw samples from the posterior
+inferencer.sample((1000, ))
+
+# Create a pairplot from the samples
+labels_params = [r'$\overline{g}_{l}$', r'$\overline{g}_{Na}$',
+                 r'$\overline{g}_{K}$', r'$\overline{C}_{m}$']
+inferencer.pairplot(labels=labels_params)
+
+# Generate traces by using a single sample from the trained posterior
+inf_traces = inferencer.generate_traces()
+
+# Visualize the traces
+t = arange(0, out_traces.shape[1] * dt / ms, dt / ms)
+nrows = 2
+ncols = out_traces.shape[0]
+fig, axs = subplots(nrows, ncols, sharex=True,
+                    gridspec_kw={'height_ratios': [3, 1]}, figsize=(15, 4))
+for idx in range(ncols):
+    axs[0, idx].plot(t, out_traces[idx, :].T, label='measurements')
+    axs[0, idx].plot(t, inf_traces[idx, :].T / mV, label='fits')
+    axs[1, idx].plot(t, inp_traces[idx, :].T * amp / nA, 'k-',
+                     label='stimulus')
+    axs[1, idx].set_xlabel('t, ms')
+    if idx == 0:
+        axs[0, idx].set_ylabel('$v$, mV')
+        axs[1, idx].set_ylabel('$I$, nA')
+handles, labels = [(h + l) for h, l
+                   in zip(axs[0, idx].get_legend_handles_labels(),
+                          axs[1, idx].get_legend_handles_labels())]
+fig.legend(handles, labels)
+tight_layout()
+show()
diff --git a/requirements.txt b/requirements.txt
index 83b5aee..2d05504 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ scikit-learn!=0.23.0
 lmfit
 tqdm
 efel
+sbi
diff --git a/setup.py b/setup.py
index bd9d745..98ac0bb 100644
--- a/setup.py
+++ b/setup.py
@@ -1,27 +1,22 @@
 #! /usr/bin/env python
-'''
-brian2modelfitting setup script
-'''
+"""brian2modelfitting setup script"""
+
 import os
-import sys
 
 from setuptools import setup, find_packages
 
 
-def readme():
-    with open('README.md') as f:
-        return f.read()
 
 version = {}
 with open(os.path.join('brian2modelfitting', 'version.py')) as fp:
     exec(fp.read(), version)
 
-# Note that this does not set a version number explicitly, but automatically
-# figures out a version based on git tags
+with open('README.md') as f:
+    long_description = f.read()
+
 setup(name='brian2modelfitting',
       url='https://github.com/brian-team/brian2modelfitting',
       version=version['version'],
       packages=find_packages(),
-      # package_data={},
       install_requires=['numpy',
                         'brian2>=2.2',
                         'setuptools',
@@ -32,25 +27,25 @@
       provides=['brian2modelfitting'],
       extras_require={'test': ['pytest'],
                       'docs': ['sphinx>=1.8'],
-                      'full': ['efel',
-                               'lmfit']},
+                      'full': ['efel', 'lmfit', 'sbi'],
+                      },
       python_requires='>=3.6',
       use_2to3=False,
       zip_safe=False,
       description='Modelfitting Toolbox for the Brian 2 simulator',
-      long_description=readme(),
+      long_description=long_description,
       author='Aleksandra Teska, Marcel Stimberg, Romain Brette, Dan Goodman',
       author_email='team@briansimulator.org',
       license='CeCILL-2.1',
       classifiers=[
           'Development Status :: 4 - Beta',
           'Intended Audience :: Science/Research',
-          'License :: OSI Approved :: CEA CNRS Inria Logiciel Libre License, version 2.1 (CeCILL-2.1)',
+          'License :: OSI Approved :: CEA CNRS Inria Logiciel Libre License, '
+          'version 2.1 (CeCILL-2.1)',
           'Natural Language :: English',
           'Operating System :: OS Independent',
           'Programming Language :: Python',
           'Programming Language :: Python :: 3',
-          'Topic :: Scientific/Engineering :: Bio-Informatics'
-          ],
-      keywords='model fitting computational neuroscience'
+          'Topic :: Scientific/Engineering :: Bio-Informatics'],
+      keywords='model fitting computational neuroscience',
      )
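A closing practical note (commentary, not part of the diff): training is the expensive step, while the trained posterior is an ordinary torch-backed object, so it can be persisted between sessions with plain `torch.save`/`torch.load`, assuming the posterior object of your `sbi` version pickles cleanly; the file name below is arbitrary:

```python
import torch

# persist the trained posterior once ``infere`` has run
torch.save(inferencer.posterior, 'hh_posterior.pt')

# ...later, reload it and sample without retraining
posterior = torch.load('hh_posterior.pt')
samples = inferencer.sample((1000, ), posterior=posterior)
```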