minor updates
antelk committed Jul 10, 2021
1 parent 88fc065 commit 7461b39
Showing 3 changed files with 87 additions and 86 deletions.
132 changes: 67 additions & 65 deletions brian2modelfitting/inferencer.py
@@ -197,8 +197,8 @@ def __init__(self, dt, model, input, output, features, method=None,

         # input data traces
         if not isinstance(input, Mapping):
-            raise TypeError('``input`` argument must be a dictionary mapping'
-                            ' the name of the input variable and ``input``.')
+            raise TypeError('`input` argument must be a dictionary mapping'
+                            ' the name of the input variable and `input`.')
         if len(input) > 1:
             raise NotImplementedError('Only a single input is supported.')
         input_var = list(input.keys())[0]
@@ -208,8 +208,8 @@ def __init__(self, dt, model, input, output, features, method=None,

         # output data traces
         if not isinstance(output, Mapping):
-            raise TypeError('``output`` argument must be a dictionary mapping'
-                            ' the name of the output variable and ``output``')
+            raise TypeError('`output` argument must be a dictionary mapping'
+                            ' the name of the output variable and `output`.')
         output_var = list(output.keys())
         output = list(output.values())
         for o_var in output_var:
@@ -309,7 +309,7 @@ class while generating data for training the neural density
         """
         if self.n_samples is None:
             raise ValueError('Number of samples is not yet defined.'
-                             'Call ``generate_training_data`` method first.')
+                             'Call `generate_training_data` method first.')
         return self.n_traces * self.n_samples

    def setup_simulator(self, network_name, n_neurons, output_var, param_init,
@@ -486,15 +486,15 @@ def save_summary_statistics(self, f, theta=None, x=None):
         elif self.theta is not None:
             t = self.theta
         else:
-            raise AttributeError('Provide sampled prior or call '
-                                 '``infere_step``method first.')
+            raise AttributeError('Provide sampled prior or call'
+                                 ' `infere_step` method first.')
         if x is not None:
             pass
         elif self.x is not None:
             x = self.x
         else:
             raise AttributeError('Provide summary feautures or call '
-                                 '``infere_step``method first.')
+                                 ' `infere_step` method first.')
         np.savez_compressed(f, theta=t, x=x)

    def load_summary_statistics(self, f, **kwargs):
@@ -517,12 +517,9 @@ def load_summary_statistics(self, f, **kwargs):
             Consisting of sampled prior and summary statistics arrays.
         """
         loaded = np.load(f, allow_pickle=True)
-        if loaded.files == ['theta', 'x'] or loaded.files == ['x', 'theta']:
+        if set(loaded.files) == {'theta', 'x'}:
             theta = loaded['theta']
             x = loaded['x']
-        elif loaded.files == ['arr_0', 'arr_1']:
-            theta = loaded['arr_0']
-            x = loaded['arr_1']
         self.theta = theta
         self.x = x
         return (theta, x)
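
Note on the simplified loader: `load_summary_statistics` now only accepts archives whose keyword arrays are named `theta` and `x` (the format `save_summary_statistics` writes), and the order-insensitive set comparison replaces the two hard-coded list orderings. A minimal standalone sketch of that `.npz` round trip, with made-up arrays in place of the real training data:

import numpy as np

theta = np.random.default_rng(0).normal(size=(100, 2))  # stand-in sampled parameters
x = np.random.default_rng(1).normal(size=(100, 5))      # stand-in summary statistics
np.savez_compressed('training_data.npz', theta=theta, x=x)

loaded = np.load('training_data.npz', allow_pickle=True)
# keyword arrays are listed in ``loaded.files``; a set comparison is
# robust to their order in the archive
if set(loaded.files) == {'theta', 'x'}:
    theta_loaded, x_loaded = loaded['theta'], loaded['x']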
@@ -554,8 +551,8 @@ def init_inference(self, inference_method, density_estimator_model, prior,
             inference_method = str.upper(inference_method)
             inference_method_fun = getattr(sbi.inference, inference_method)
         except AttributeError:
-            raise NameError(f'Inference method {inference_method} is not '
-                            'supported. Choose between SNPE, SNLE or SNRE.')
+            raise NameError(f'Inference method {inference_method} is not'
+                            ' supported. Choose between SNPE, SNLE or SNRE.')
         finally:
             if inference_method == 'SNPE':
                 density_estimator_builder = posterior_nn(
@@ -671,7 +668,7 @@ def infere_step(self, proposal, inference,
         # extract the training data and make adjustments for ``sbi``
         if theta is None:
             if n_samples is None:
-                raise ValueError('Either provide ``theta`` or ``n_samples``.')
+                raise ValueError('Either provide `theta` or `n_samples`.')
             else:
                 theta = self.generate_training_data(n_samples, proposal)
         self.theta = theta
@@ -680,7 +677,7 @@ def infere_step(self, proposal, inference,
         # extract the summary statistics and make adjustments for ``sbi``
         if x is None:
             if n_samples is None:
-                raise ValueError('Either provide ``x`` or ``n_samples``.')
+                raise ValueError('Either provide `x` or `n_samples`.')
             else:
                 x = self.extract_summary_statistics(theta, level=2)
         self.x = x
@@ -695,6 +692,7 @@ def infere_step(self, proposal, inference,
         inference, posterior = self.build_posterior(inference,
                                                     density_estimator,
                                                     **posterior_kwargs)
+        self.inference = inference
         self.posterior = posterior
         return posterior

@@ -742,65 +740,69 @@ def infere(self, n_samples=None, theta=None, x=None, n_rounds=1,
         sbi.inference.NeuralPosterior
             Trained posterior.
         """
-        # handle the number of rounds
-        if not isinstance(n_rounds, int):
-            raise ValueError('Number of rounds must be a positive integer.')
-
-        # handle inference methods
-        try:
-            inference_method = str.upper(inference_method)
-        except ValueError as e:
-            print(e, '\nInvalid inference method.')
-        if inference_method not in ['SNPE', 'SNLE', 'SNRE']:
-            raise ValueError(f'Inference method {inference_method} is not '
-                             'supported. Choose between SNPE, SNLE or SNRE.')
-
-        # initialize prior
-        if self.posterior is None:
-            prior = self.init_prior(**params)
-        else:
-            prior = self.posterior.set_default_x(self.x_o)
-
-        # extract the training data and make adjustments for ``sbi``
-        if theta is None:
-            if n_samples is None:
-                raise ValueError('Either provide ``theta`` or ``n_samples``.')
-            else:
-                theta = self.generate_training_data(n_samples, prior)
-        self.theta = theta
-
-        # extract the summary statistics and make adjustments for ``sbi``
-        if x is None:
-            if n_samples is None:
-                raise ValueError('Either provide ``x`` or ``n_samples``.')
-            else:
-                x = self.extract_summary_statistics(theta)
-        self.x = x
-
-        # initialize inference object
-        inference = self.init_inference(inference_method,
-                                        density_estimator_model,
-                                        prior,
-                                        **inference_kwargs)
+        if self.posterior is None:  # `.infere_step` has not been called
+            # handle the number of rounds
+            if not isinstance(n_rounds, int):
+                raise ValueError('`n_rounds` has to be a positive integer.')
+
+            # handle inference methods
+            try:
+                inference_method = str.upper(inference_method)
+            except ValueError as e:
+                print(e, '\nInvalid inference method.')
+            if inference_method not in ['SNPE', 'SNLE', 'SNRE']:
+                raise ValueError(f'Inference method {inference_method} is not'
+                                 ' supported.')
+
+            # initialize prior
+            prior = self.init_prior(**params)
+
+            # extract the training data and make adjustments for ``sbi``
+            if theta is None:
+                if n_samples is None:
+                    raise ValueError('Either provide `theta` or `n_samples`.')
+                else:
+                    theta = self.generate_training_data(n_samples, prior)
+            self.theta = theta
+
+            # extract the summary statistics and make adjustments for ``sbi``
+            if x is None:
+                if n_samples is None:
+                    raise ValueError('Either provide `x` or `n_samples`.')
+                else:
+                    x = self.extract_summary_statistics(theta)
+            self.x = x
+
+            # initialize inference object
+            self.inference = self.init_inference(inference_method,
+                                                 density_estimator_model,
+                                                 prior,
+                                                 **inference_kwargs)
+
+            # additional args for `.train` method are needed only for SNPE
+            if inference_method == 'SNPE':
+                args = [prior]
+            else:
+                args = []
+        else:  # `.infere_step` has been called manually
+            prior = self.posterior.set_default_x(self.x_o)
+            if self.posterior._method_family == 'snpe':
+                args = [prior]
+            else:
+                args = []

         # allocate empty list of posterior
         posteriors = []

         # set a proposal
         proposal = prior

-        # additional arguments for `.train` method are needed only for SNPE
-        if inference_method == 'SNPE':
-            args = [proposal]
-        else:
-            args = []
-
         # main inference loop
         for round in range(n_rounds):
             print(f'Round {round + 1}/{n_rounds}.')

             # inference step
-            posterior = self.infere_step(proposal, inference,
+            posterior = self.infere_step(proposal, self.inference,
                                          n_samples, self.theta, self.x,
                                          train_kwargs, posterior_kwargs, *args)

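Note on the restructuring: the per-round loop itself is unchanged; only the one-time setup (argument checks, prior, training data, summary statistics, and the `init_inference` call, now stored on `self.inference`) moves under the `if self.posterior is None` branch, while the `else` branch reuses the state left behind by an earlier manual `infere_step` call. For reference, a minimal self-contained multi-round loop written directly against `sbi` with a toy Gaussian simulator and a made-up observation (an illustration only, not the Brian 2 code path; it assumes a recent `sbi` release exposing `SNPE`, `prepare_for_sbi`, `simulate_for_sbi` and `BoxUniform`):

import torch
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.utils import BoxUniform


def toy_simulator(theta):
    # hypothetical stand-in for the Brian 2 simulation + feature extraction
    return theta + 0.1 * torch.randn_like(theta)


prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
simulator, prior = prepare_for_sbi(toy_simulator, prior)
x_o = torch.zeros(2)  # made-up observation

inference = SNPE(prior=prior, density_estimator='maf')
proposal = prior
for _ in range(2):  # two rounds, as with n_rounds=2
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=500)
    density_estimator = inference.append_simulations(
        theta, x, proposal=proposal).train()
    posterior = inference.build_posterior(density_estimator)
    proposal = posterior.set_default_x(x_o)

samples = posterior.sample((1000, ), x=x_o)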
@@ -872,9 +874,9 @@ def sample(self, shape, posterior=None, **kwargs):
         elif posterior is None and self.posterior:
             p = self.posterior
         else:
-            raise ValueError('Need to provide posterior argument if no '
-                             'posterior has been calculated by the ``infere`` '
-                             'method.')
+            raise ValueError('Need to provide posterior argument if no'
+                             ' posterior has been calculated by the `infere`'
+                             ' method.')
         samples = p.sample(shape, x=self.x_o, **kwargs)
         self.samples = samples
         return samples
Expand Down Expand Up @@ -904,7 +906,7 @@ def pairplot(self, samples=None, **kwargs):
try:
s = self.samples
except AttributeError as e:
print(e, '\nProvide samples or call ``sample`` method first.')
print(e, '\nProvide samples or call `sample` method first.')
raise
fig, axes = sbi.analysis.pairplot(s, **kwargs)
return fig, axes
@@ -942,9 +944,9 @@ def generate_traces(self, posterior=None, output_var=None, param_init=None,
         elif posterior is None and self.posterior:
             p = self.posterior
         else:
-            raise ValueError('Need to provide posterior argument if no '
-                             'posterior has been calculated by the ``infere`` '
-                             'method.')
+            raise ValueError('Need to provide posterior argument if no'
+                             ' posterior has been calculated by the `infere`'
+                             ' method.')
         params = p.sample((1, ), x=self.x_o)

         # set output variable that is monitored
10 changes: 5 additions & 5 deletions examples/IF_sbi.py
@@ -49,25 +49,25 @@
 def n_peaks(x):
     n_p = []
     for _x in x.transpose():
-        p_i = find_peaks(_x, height=-60)[0]
+        p_i = find_peaks(_x)[0]
         n_p.append(p_i.size)
     return n_p


 inferencer = Inferencer(dt=dt, model=eqs_inf,
                         input={'I_syn': inp_trace.reshape(1, -1)},
                         output={'v': out_trace},
-                        features=[lambda x: x[(t > start_syn) & (t < end_syn), :].mean(axis=0),
-                                  lambda x: x[(t > start_syn) & (t < end_syn), :].std(axis=0),
-                                  lambda x: x[(t > start_syn) & (t < end_syn), :].max(axis=0),
+                        features=[lambda x: x[(t > start_syn) & (t < end_syn), :].mean(axis=1),
+                                  lambda x: x[(t > start_syn) & (t < end_syn), :].std(axis=1),
+                                  lambda x: x[(t > start_syn) & (t < end_syn), :].max(axis=1),
                                   n_peaks],
                         method='exponential_euler',
                         threshold='v > -50 * mV',
                         reset='v = -70 * mV',
                         param_init={'v': -70 * mV})

 inferencer.infere(n_samples=1000,
-                  inference_method='SNLE',
+                  inference_method='SNPE',
                   gl=[10*nS, 100*nS],
                   C=[0.1*nF, 10*nF])

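Note on the `find_peaks` change: without `height=-60` every local maximum is counted, not only suprathreshold ones, so `n_peaks` becomes sensitive to subthreshold ripples. A standalone sketch with synthetic traces (made-up data, not the simulation output) showing the difference:

import numpy as np
from scipy.signal import find_peaks

t = np.linspace(0, 1, 1000)
traces = np.stack([-75 + 10 * np.sin(2 * np.pi * 5 * t),     # stays below -60
                   -65 + 40 * np.sin(2 * np.pi * 5 * t)]).T  # crosses -60


def count_peaks(x, height=None):
    return [find_peaks(_x, height=height)[0].size for _x in x.transpose()]


print(count_peaks(traces, height=-60))  # expected [0, 5]: only peaks above -60
print(count_peaks(traces))              # expected [5, 5]: every local maximum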
31 changes: 15 additions & 16 deletions examples/hh_sbi_advanced.py
@@ -1,3 +1,5 @@
+import os
+
 from brian2 import *
 from brian2modelfitting import *
 import pandas as pd
@@ -66,18 +68,18 @@ def n_peaks(x):
                               g_na=[2e-06*siemens, 2e-04*siemens],
                               g_kd=[6e-07*siemens, 6e-05*siemens],
                               Cm=[0.1*uF*cm**-2*area, 2*uF*cm**-2*area])
-# Generate training data
-theta = inferencer.generate_training_data(n_samples=1000,
-                                           prior=prior)
-# Extract summary stats
-x = inferencer.extract_summary_statistics(theta=theta, level=0)
-
-# Loading and storing of the data
-# For large training data, you can store it into .npz and reuse it later
+# Prepare training data
 path_to_data = __file__[:-3] + '_data.npz'
-inferencer.save_summary_statistics(path_to_data, theta, x)
-# Training data will be directly set to class variables once they are loaded
-theta_loaded, x_loaded = inferencer.load_summary_statistics(path_to_data)
+if os.path.exists(path_to_data):
+    theta, x = inferencer.load_summary_statistics(path_to_data)
+else:
+    # Generate training data
+    theta = inferencer.generate_training_data(n_samples=1000,
+                                               prior=prior)
+    # Extract summary stats
+    x = inferencer.extract_summary_statistics(theta=theta, level=0)
+    # Save the data for later use
+    inferencer.save_summary_statistics(path_to_data, theta, x)

 # Amortized inference
 # Training the neural density estimator
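
Note on the caching: the example now regenerates summary statistics only when the `.npz` file is missing. A standalone sketch of the same load-or-generate pattern, with toy arrays and a hypothetical cache file name standing in for the real data:

import os

import numpy as np

path_to_data = 'hh_sbi_advanced_data.npz'  # hypothetical cache file
if os.path.exists(path_to_data):
    cached = np.load(path_to_data)
    theta, x = cached['theta'], cached['x']
else:
    # stand-in for the expensive generate/extract step
    rng = np.random.default_rng(0)
    theta = rng.uniform(size=(1000, 4))
    x = theta @ np.ones((4, 3)) + 0.1 * rng.normal(size=(1000, 3))
    np.savez_compressed(path_to_data, theta=theta, x=x)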
@@ -89,7 +91,7 @@ def n_peaks(x):
 # First round of inference where no observation data is set to posterior
 posterior = inferencer.infere_step(proposal=prior,
                                    inference=inference,
-                                   theta=theta_loaded, x=x_loaded,
+                                   theta=theta, x=x,
                                    train_kwargs={'num_atoms': 10,
                                                  'learning_rate': 0.0005,
                                                  'show_train_summary': True})
@@ -107,10 +109,7 @@ def n_peaks(x):
                             r'$\overline{g}_{K}$',
                             r'$\overline{C}_{m}$'])
 # ...and optionally, continue the multiround inference via ``infere`` method
-posterior_multi_round = inferencer.infere(n_rounds=2,
-                                          theta=theta_loaded, x=x_loaded,
-                                          inference_method='SNPE',
-                                          density_estimator_model='maf')
+posterior_multi_round = inferencer.infere(n_rounds=2)
 inferencer.sample((10000, ))
 inferencer.pairplot(labels=[r'$\overline{g}_{l}$',
                             r'$\overline{g}_{Na}$',
