Skip to content

Commit

Permalink
Add tests for default optimizer/metric (brian-team#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
Eslam Khaled committed Apr 7, 2021
1 parent f4a0f6f commit 0a0be47
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
45 changes: 42 additions & 3 deletions brian2modelfitting/tests/test_modelfitting_spikefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,58 @@ def test_spikefitter_fit(setup):

def test_spikefitter_fit_errors(setup):
    """SpikeFitter.fit() must reject invalid metric/optimizer args with TypeError."""
    dt, sf = setup

    # Stand-in object that is NOT derived from the Optimizer base class.
    class NaiveOptimizer:
        def __init__(self):
            self.best = []

    naive_opt = NaiveOptimizer()
    with pytest.raises(TypeError):
        # MSEMetric is a trace metric — invalid for spike fitting
        results, errors = sf.fit(n_rounds=2,
                                 optimizer=n_opt,
                                 metric=MSEMetric(),  # testing a wrong metric
                                 gL=[20*nS, 40*nS],
                                 C=[0.5*nF, 1.5*nF])
    with pytest.raises(TypeError):
        # An object that is not an Optimizer subclass must be rejected
        results, errors = sf.fit(n_rounds=2,
                                 optimizer=naive_opt,  # testing a Non-Optimizer child
                                 metric=metric,
                                 gL=[20*nS, 40*nS],
                                 C=[0.5*nF, 1.5*nF])

def test_fitter_fit_default_optimizer(setup):
    """Fitting with optimizer=None should fall back to NevergradOptimizer."""
    dt, sf = setup
    fit_kwargs = dict(n_rounds=2,
                      optimizer=None,
                      metric=metric,
                      gL=[20*nS, 40*nS],
                      C=[0.5*nF, 1.5*nF])
    results, errors = sf.fit(**fit_kwargs)

    assert sf.simulator.neurons.iteration == 1
    for attribute in ('optimizer', 'metric', 'best_params'):
        assert hasattr(sf, attribute)

    # optimizer=None must have been replaced by the default optimizer
    assert isinstance(sf.optimizer, NevergradOptimizer)
    assert isinstance(sf.simulator, Simulator)

    assert_equal(results, sf.best_params)
    assert_equal(errors, sf.best_error)


def test_spikefitter_fit_default_metric(setup):
    """Fitting with metric=None should fall back to GammaFactor for spikes."""
    dt, sf = setup
    fit_kwargs = dict(n_rounds=2,
                      optimizer=n_opt,
                      metric=None,
                      gL=[20*nS, 40*nS],
                      C=[0.5*nF, 1.5*nF])
    results, errors = sf.fit(**fit_kwargs)

    assert sf.simulator.neurons.iteration == 1
    for attribute in ('optimizer', 'metric', 'best_params'):
        assert hasattr(sf, attribute)

    # metric=None must have been replaced by the default spike metric
    assert isinstance(sf.metric, GammaFactor)
    assert isinstance(sf.simulator, Simulator)

    assert_equal(results, sf.best_params)
    assert_equal(errors, sf.best_error)


def test_spikefitter_param_init(setup):
Expand Down
44 changes: 43 additions & 1 deletion brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,25 @@ def test_tracefitter_init_errors(setup):
output=np.array(output_traces), # no units
n_samples=2)


def test_tracefitter_fit_default_metric(setup):
    """Fitting with metric=None should fall back to MSEMetric for traces."""
    dt, tf = setup
    fit_kwargs = dict(n_rounds=2,
                      optimizer=n_opt,
                      metric=None,
                      g=[1*nS, 30*nS],
                      callback=None)
    results, errors = tf.fit(**fit_kwargs)

    assert tf.simulator.neurons.iteration == 1
    for attribute in ('optimizer', 'metric', 'best_params'):
        assert hasattr(tf, attribute)

    # metric=None must have been replaced by the default trace metric
    assert isinstance(tf.metric, MSEMetric)
    assert isinstance(tf.simulator, Simulator)

    assert_equal(results, tf.best_params)
    assert_equal(errors, tf.best_error)


from nevergrad.optimization import registry as nevergrad_registry
@pytest.mark.parametrize('method', sorted(nevergrad_registry.keys()))
def test_fitter_fit_methods(method):
Expand Down Expand Up @@ -341,6 +360,25 @@ def test_fitter_fit_no_units(setup_no_units):
assert_equal(errors, tf.best_error)


def test_fitter_fit_default_optimizer(setup):
    """Fitting with optimizer=None should fall back to NevergradOptimizer."""
    dt, tf = setup
    fit_kwargs = dict(n_rounds=2,
                      optimizer=None,
                      metric=metric,
                      g=[1*nS, 30*nS],
                      callback=None)
    results, errors = tf.fit(**fit_kwargs)

    assert tf.simulator.neurons.iteration == 1
    for attribute in ('optimizer', 'metric', 'best_params'):
        assert hasattr(tf, attribute)

    # optimizer=None must have been replaced by the default optimizer
    assert isinstance(tf.optimizer, NevergradOptimizer)
    assert isinstance(tf.simulator, Simulator)

    assert_equal(results, tf.best_params)
    assert_equal(errors, tf.best_error)


def test_fitter_fit_callback(setup):
dt, tf = setup

Expand Down Expand Up @@ -381,9 +419,13 @@ def our_callback(params, errors, best_params, best_error, index):

def test_fitter_fit_errors(setup):
    """TraceFitter.fit() must raise TypeError for a non-Optimizer optimizer."""
    dt, tf = setup

    # Stand-in object that is NOT derived from the Optimizer base class.
    class NaiveOptimizer:
        def __init__(self):
            self.best = []

    opt = NaiveOptimizer()
    with pytest.raises(TypeError):
        tf.fit(n_rounds=2,
               optimizer=opt,  # testing a Non-Optimizer child
               metric=metric,
               g=[1*nS, 30*nS])

Expand Down

0 comments on commit 0a0be47

Please sign in to comment.