Add sbi wrapper class #49
Conversation
Everything is looking 👍 in general. I left a few specific comments, but nothing major. Regarding the features, I think I'd go for a short-term solution before dealing with #48 (there are some subtleties we have to take into account, e.g. eFEL often needs stimulus times as an additional argument). This solution would be that the train function (and/or Inferencer.__init__) takes a list of functions, each of which takes as input a 2d array of traces (samples × inputs) and gives as output either a 2d array (features × inputs) or a 1d array of shape (inputs, ), as a convenience for scalar features. These will then be flattened and concatenated before being handed over to sbi (see the sketch below). Note that I did not restrict this to scalar features: I think in some cases it is simpler to write a single function that returns several features instead of asking the user to split them up into multiple functions. The only requirement is that the number of features returned by a function never changes (so something like "spike times" would not be a valid feature, but "min and max membrane potential" would be). In the long run we should certainly check this to give a nice error message, but in the short run it is ok if it just fails somewhere in the pipeline with a shape mismatch.
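To make this contract concrete, here is a minimal sketch of two such feature functions, assuming the (samples × inputs) layout described above; the function names and the flattening snippet are only illustrative:

```python
import numpy as np

def min_max_voltage(traces):
    """2d array (samples x inputs) in, 2d array (features x inputs) out:
    here, two features (min and max) per input trace."""
    return np.vstack([traces.min(axis=0), traces.max(axis=0)])

def mean_voltage(traces):
    """Scalar feature: returns a 1d array of shape (inputs, )."""
    return traces.mean(axis=0)

# illustrative only: outputs of all feature functions are flattened and
# concatenated into a single summary-statistics vector for sbi
feature_list = [min_max_voltage, mean_voltage]
traces = np.random.rand(1000, 3)  # 1000 time samples, 3 input traces
x = np.concatenate([np.atleast_2d(f(traces)).flatten() for f in feature_list])
```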
brian2modelfitting/utils.py (outdated)

    @@ -2,7 +2,7 @@
     from brian2 import have_same_dimensions
     from brian2.units.fundamentalunits import Quantity
    -from tqdm.autonotebook import tqdm
    +from tqdm import tqdm
I think there was a reason for the autonotebook import (nicer progress bars in notebooks, I guess?). I'm not sure what the best solution is here; you presumably removed it because of the warning?
Yes, autonotebook looks better, but it also throws the warning and sometimes fails because IProgress is not found. Of course, I will change this import back to the way it was done so far and use tqdm from the autonotebook module.
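One possible compromise, sketched below under the assumption that the warning is the main annoyance, would be to keep the notebook-aware import but silence the experimental warning that tqdm.autonotebook can emit; whether this also avoids the IProgress problem depends on ipywidgets being installed:

```python
import warnings

# sketch: keep the notebook-aware progress bars, but hide the
# TqdmExperimentalWarning that tqdm.autonotebook may emit on import
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from tqdm.autonotebook import tqdm
```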
brian2modelfitting/inferencer.py (outdated)

    Parameters
    ----------
    param_values : iterable
        Iterable of size (`n_samples`, `len(param_names)` containing
Minor nitpick (for here and in many other places): the docstrings are reStructuredText, not Markdown, so single backticks are meant to create links to class/function names, while double backticks should be used for simple "typewriter" text.
Oh, right. I will change that 👍
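For reference, a small illustration of the convention; the method and parameter names here are made up for the example:

```python
def sample(self, n_samples):
    """Draw samples from the trained posterior.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw. Literal names such as ``n_samples``
        use double backticks ("typewriter" text), while single backticks
        are reserved for cross-references, e.g. `Inferencer.train`.
    """
```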
brian2modelfitting/inferencer.py (outdated)

    for o in self.output:
        o_dim = get_dimensions(o)
        o_obs = self.extract_features(o.transpose(), o_dim)
        x_o.append(o_obs.flatten())
You should not need to get the units and divide by them; you can always call np.array around quantities (which will not even perform a copy). In some places you can also get a unit-less version by appending an underscore (this is a Brian convention), e.g. statemonitor.t_ would give you the recorded time points without units.
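A short illustration of both approaches; the statemonitor attribute in the comment follows the usual Brian 2 monitor naming and is not code from this PR:

```python
import numpy as np
from brian2 import mV

v = np.array([-70.0, -65.0, -50.0]) * mV  # a Brian 2 Quantity (ndarray subclass)
v_plain = np.asarray(v)                   # plain ndarray view, units dropped, no copy
# Brian convention: a trailing underscore on monitor attributes returns the
# recorded values without units, e.g. statemonitor.t_ or statemonitor.v_
```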
brian2modelfitting/inferencer.py (outdated)

    x_val = obs[ov].get_value_with_unit()
    x_dim = get_dimensions(obs[ov])
    features = self.extract_features(x_val, x_dim)
    x.append(features)
See the above comment about units; e.g. here you could simply use get_value().
brian2modelfitting/inferencer.py (outdated)

    # use the density estimator to build the posterior
    posterior = inference.build_posterior(de)

    # append the current posterior to the list of posteriors
    posteriors.append(posterior)

    # update the proposal given the observation
    proposal = posterior.set_default_x(x_o)
I think it would be good to divide this up a bit further, e.g. have one method to simulate from the priors (or from posteriors from a previous run in the multi-round approach), one method to train the network, and one method to build the posterior. There are several reasons for this: 1) it makes it possible to train a network and then use the trained network on different experimental data (e.g. different cells); 2) it can be useful for debugging/sanity checks (e.g. if half of the simulations from the prior distribution give completely senseless results, it probably makes sense to adapt the prior before starting to train a network); and 3) sbi's train() and build_posterior() methods both take several arguments, and for advanced users it would be good if they could provide them (which would be a mess with the current Inferencer.train() method). Having said all that, I think your Inferencer.train() method is a good convenience method for the most common use case, so I'd keep it and just have it call out to the more fine-grained methods.
Yes, I completely agree.
I was thinking of enabling a simple call to the train method that does everything for the user, while breaking this up into smaller chunks for more advanced users: for example, one method for the actual training of the estimator, one method for building the posterior, one method for drawing samples from the posterior and, finally, one method for visual inspection (a rough sketch follows below).
Also, with this, we will be able to create easier calls for storing and loading the trained posterior; see issue #46.
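A purely hypothetical sketch of how the convenience method could delegate to such fine-grained steps; every method name, signature and attribute below (e.g. self.prior, self.x_o) is a placeholder, not the final API:

```python
class Inferencer:
    """Skeleton only -- names and signatures are placeholders."""

    def infer_step(self, proposal, n_samples):
        """Simulate the model for parameters drawn from the proposal
        and return (theta, x), i.e. parameters and summary features."""
        raise NotImplementedError

    def train_network(self, theta, x, **train_kwargs):
        """Train sbi's density estimator, forwarding extra arguments."""
        raise NotImplementedError

    def build_posterior(self, density_estimator, **posterior_kwargs):
        """Wrap the trained estimator into a posterior object."""
        raise NotImplementedError

    def infer(self, n_samples, n_rounds=1, **train_kwargs):
        """Convenience method for the common case: run all steps,
        optionally over several rounds (multi-round inference)."""
        proposal = self.prior
        for _ in range(n_rounds):
            theta, x = self.infer_step(proposal, n_samples)
            estimator = self.train_network(theta, x, **train_kwargs)
            posterior = self.build_posterior(estimator)
            proposal = posterior.set_default_x(self.x_o)
        return posterior
```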
examples/hh_sbi.py (outdated)

    labels_params = [r'$\overline{g}_{l}$', r'$\overline{g}_{Na}$',
                     r'$\overline{g}_{K}$', r'$\overline{C}_{m}$']
    samples = inferencer.sample(1000, viz=True,
                                labels=labels_params, figsize=(10, 10))
Yet another minor thing: for a quick look at how the fitting went, it would be good to introduce an Inferencer.generate method (like Fitter.generate) that can then be used to visualize the fitted traces in comparison to the experimental ones (as in the other examples).
Sure thing :)
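A hypothetical usage sketch, continuing the hh_sbi.py example above; Inferencer.generate does not exist yet and the call below is only a placeholder for the idea:

```python
# hypothetical: re-simulate with an inferred parameter set and keep the
# resulting traces for plotting next to the experimental ones
samples = inferencer.sample(1000)
fitted_traces = inferencer.generate(params=samples[0])
```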
Marcel, thanks for the great comments!
So, just to avoid a misunderstanding, I meant that the inferencer gets something like …
That's exactly what I was thinking. Anyway, thank you for the clarification!
The last few commits deal with some minor docstring issues.
Advanced users can now use each of the methods separately. Also, functionality for generating fitted traces by using the inferred parameters has been added. The example is updated according to these changes.
Ok, all looks good to me for a merge – I made some minor changes along the way, but nothing really important. We'll need to get into tests and documentation at some point of course, but for now let's continue a bit until the API is settled.
Great, thanks!
Attempt to resolve issue #44.
The Inferencer class is fully functional and can be tested via the hh_sbi.py file in examples. It supports both multi-round inference and multi-input/output traces.
Even though the class is ready to use, the feature metric is not yet fully supported. The actual feature metric should be supported by resolving issue #48.
All additional changes are more or less stylistic improvements.
sbi is added to the requirements file and to setup.py as an optional dependency.
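For reference, a minimal sketch of how such an optional dependency could be declared; the extras name and the exact setup.py contents are assumptions, not necessarily what is in this PR:

```python
from setuptools import setup, find_packages

setup(
    name='brian2modelfitting',
    packages=find_packages(),
    install_requires=['brian2'],
    # hypothetical extras entry: install with
    #   pip install brian2modelfitting[sbi]
    extras_require={'sbi': ['sbi']},
)
```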