From e3249f9ac978e63bd214dbd513e5c2b0ce307ba5 Mon Sep 17 00:00:00 2001 From: Max Burg Date: Fri, 18 Mar 2022 14:14:01 +0100 Subject: [PATCH 1/6] [del] unused standard_trainer, black formatting --- nndichromacy/training/trainers.py | 239 +----------------------------- 1 file changed, 7 insertions(+), 232 deletions(-) diff --git a/nndichromacy/training/trainers.py b/nndichromacy/training/trainers.py index c3943da9..d1f44248 100644 --- a/nndichromacy/training/trainers.py +++ b/nndichromacy/training/trainers.py @@ -16,214 +16,6 @@ from ..utility.measures import get_correlations, get_poisson_loss -def standart_trainer( - model, - dataloaders, - seed, - avg_loss=False, - scale_loss=True, # trainer args - loss_function="PoissonLoss", - stop_function="get_correlations", - loss_accum_batch_n=None, - device="cuda", - verbose=True, - interval=1, - patience=5, - epoch=0, - lr_init=0.005, # early stopping args - max_iter=100, - maximize=True, - tolerance=1e-6, - restore_best=True, - lr_decay_steps=3, - lr_decay_factor=0.3, - min_lr=0.0001, # lr scheduler args - cb=None, - track_training=False, - return_test_score=False, - **kwargs -): - """ - - Args: - model: - dataloaders: - seed: - avg_loss: - scale_loss: - loss_function: - stop_function: - loss_accum_batch_n: - device: - verbose: - interval: - patience: - epoch: - lr_init: - max_iter: - maximize: - tolerance: - restore_best: - lr_decay_steps: - lr_decay_factor: - min_lr: - cb: - track_training: - **kwargs: - - Returns: - - """ - - def full_objective(model, dataloader, data_key, *args, **kwargs): - """ - - Args: - model: - dataloader: - data_key: - *args: - - Returns: - - """ - loss_scale = ( - np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) - if scale_loss - else 1.0 - ) - return loss_scale * criterion( - model(*args, data_key=data_key, **kwargs), args[1].to(device) - ) + model.regularizer(data_key) - - ##### Model training #################################################################################################### - model.to(device) - set_random_seed(seed) - model.train() - - criterion = getattr(mlmeasures, loss_function)(avg=avg_loss) - stop_closure = partial( - getattr(measures, stop_function), - dataloaders=dataloaders["validation"], - device=device, - per_neuron=False, - avg=True, - ) - - n_iterations = len(LongCycler(dataloaders["train"])) - - optimizer = torch.optim.Adam(model.parameters(), lr=lr_init) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode="max" if maximize else "min", - factor=lr_decay_factor, - patience=patience, - threshold=tolerance, - min_lr=min_lr, - verbose=verbose, - threshold_mode="abs", - ) - - # set the number of iterations over which you would like to accummulate gradients - optim_step_count = ( - len(dataloaders["train"].keys()) - if loss_accum_batch_n is None - else loss_accum_batch_n - ) - - if track_training: - tracker_dict = dict( - correlation=partial( - get_correlations, - model=model, - dataloaders=dataloaders["validation"], - device=device, - per_neuron=False, - ), - poisson_loss=partial( - get_poisson_loss, - model=model, - datalaoders=dataloaders["validation"], - device=device, - per_neuron=False, - avg=False, - ), - ) - if hasattr(model, "tracked_values"): - tracker_dict.update(model.tracked_values) - tracker = MultipleObjectiveTracker(**tracker_dict) - else: - tracker = None - - # train over epochs - for epoch, val_obj in early_stopping( - model, - stop_closure, - interval=interval, - patience=patience, - start=epoch, - max_iter=max_iter, - 
maximize=maximize, - tolerance=tolerance, - restore_best=restore_best, - tracker=tracker, - scheduler=scheduler, - lr_decay_steps=lr_decay_steps, - ): - - # print the quantities from tracker - if verbose and tracker is not None: - print("=======================================", flush=True) - for key in tracker.log.keys(): - print(key, tracker.log[key][-1], flush=True) - - # executes callback function if passed in keyword args - if cb is not None: - cb() - - # train over batches - optimizer.zero_grad() - if hasattr(tqdm, "_instances"): - tqdm._instances.clear() - - for batch_no, (data_key, data) in tqdm( - enumerate(LongCycler(dataloaders["train"])), - total=n_iterations, - desc="Epoch {}".format(epoch), - ): - loss = full_objective( - model, dataloaders["train"], data_key, *data, **data._asdict() - ) - loss.backward() - if (batch_no + 1) % optim_step_count == 0: - optimizer.step() - optimizer.zero_grad() - - ##### Model evaluation #################################################################################################### - model.eval() - tracker.finalize() if track_training else None - - # Compute avg validation and test correlation - validation_correlation = get_correlations( - model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False - ) - if return_test_score: - test_correlation = get_correlations( - model, dataloaders["test"], device=device, as_dict=False, per_neuron=False - ) - - # return the whole tracker output as a dict - output = {k: v for k, v in tracker.log.items()} if track_training else {} - output["validation_corr"] = validation_correlation - - score = ( - np.mean(test_correlation) - if return_test_score - else np.mean(validation_correlation) - ) - return score, output, model.state_dict() - - def standard_trainer( model, dataloaders, @@ -252,6 +44,7 @@ def standard_trainer( detach_core=False, **kwargs ): + print("TRAINER 2") """ Args: @@ -286,14 +79,8 @@ def standard_trainer( def full_objective(model, dataloader, data_key, *args, detach_core): - loss_scale = ( - np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) - if scale_loss - else 1.0 - ) - regularizers = int( - not detach_core - ) * model.core.regularizer() + model.readout.regularizer(data_key) + loss_scale = np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) if scale_loss else 1.0 + regularizers = int(not detach_core) * model.core.regularizer() + model.readout.regularizer(data_key) return ( loss_scale * criterion( @@ -332,11 +119,7 @@ def full_objective(model, dataloader, data_key, *args, detach_core): ) # set the number of iterations over which you would like to accummulate gradients - optim_step_count = ( - len(dataloaders["train"].keys()) - if loss_accum_batch_n is None - else loss_accum_batch_n - ) + optim_step_count = len(dataloaders["train"].keys()) if loss_accum_batch_n is None else loss_accum_batch_n if track_training: tracker_dict = dict( @@ -396,9 +179,7 @@ def full_objective(model, dataloader, data_key, *args, detach_core): desc="Epoch {}".format(epoch), ): - loss = full_objective( - model, dataloaders["train"], data_key, *data, detach_core=detach_core - ) + loss = full_objective(model, dataloaders["train"], data_key, *data, detach_core=detach_core) loss.backward() if (batch_no + 1) % optim_step_count == 0: optimizer.step() @@ -412,19 +193,13 @@ def full_objective(model, dataloader, data_key, *args, detach_core): validation_correlation = get_correlations( model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False ) - test_correlation 
= get_correlations( - model, dataloaders["test"], device=device, as_dict=False, per_neuron=False - ) + test_correlation = get_correlations(model, dataloaders["test"], device=device, as_dict=False, per_neuron=False) # return the whole tracker output as a dict output = {k: v for k, v in tracker.log.items()} if track_training else {} output["validation_corr"] = validation_correlation - score = ( - np.mean(test_correlation) - if return_test_score - else np.mean(validation_correlation) - ) + score = np.mean(test_correlation) if return_test_score else np.mean(validation_correlation) return score, output, model.state_dict() From bcb81510ae3fa8955c27a7b4df6d8e49bee96078 Mon Sep 17 00:00:00 2001 From: Max Burg Date: Fri, 18 Mar 2022 14:14:01 +0100 Subject: [PATCH 2/6] [del] unused standard_trainer, black formatting --- nndichromacy/training/trainers.py | 238 +----------------------------- 1 file changed, 6 insertions(+), 232 deletions(-) diff --git a/nndichromacy/training/trainers.py b/nndichromacy/training/trainers.py index c3943da9..cff73b9b 100644 --- a/nndichromacy/training/trainers.py +++ b/nndichromacy/training/trainers.py @@ -16,214 +16,6 @@ from ..utility.measures import get_correlations, get_poisson_loss -def standart_trainer( - model, - dataloaders, - seed, - avg_loss=False, - scale_loss=True, # trainer args - loss_function="PoissonLoss", - stop_function="get_correlations", - loss_accum_batch_n=None, - device="cuda", - verbose=True, - interval=1, - patience=5, - epoch=0, - lr_init=0.005, # early stopping args - max_iter=100, - maximize=True, - tolerance=1e-6, - restore_best=True, - lr_decay_steps=3, - lr_decay_factor=0.3, - min_lr=0.0001, # lr scheduler args - cb=None, - track_training=False, - return_test_score=False, - **kwargs -): - """ - - Args: - model: - dataloaders: - seed: - avg_loss: - scale_loss: - loss_function: - stop_function: - loss_accum_batch_n: - device: - verbose: - interval: - patience: - epoch: - lr_init: - max_iter: - maximize: - tolerance: - restore_best: - lr_decay_steps: - lr_decay_factor: - min_lr: - cb: - track_training: - **kwargs: - - Returns: - - """ - - def full_objective(model, dataloader, data_key, *args, **kwargs): - """ - - Args: - model: - dataloader: - data_key: - *args: - - Returns: - - """ - loss_scale = ( - np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) - if scale_loss - else 1.0 - ) - return loss_scale * criterion( - model(*args, data_key=data_key, **kwargs), args[1].to(device) - ) + model.regularizer(data_key) - - ##### Model training #################################################################################################### - model.to(device) - set_random_seed(seed) - model.train() - - criterion = getattr(mlmeasures, loss_function)(avg=avg_loss) - stop_closure = partial( - getattr(measures, stop_function), - dataloaders=dataloaders["validation"], - device=device, - per_neuron=False, - avg=True, - ) - - n_iterations = len(LongCycler(dataloaders["train"])) - - optimizer = torch.optim.Adam(model.parameters(), lr=lr_init) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode="max" if maximize else "min", - factor=lr_decay_factor, - patience=patience, - threshold=tolerance, - min_lr=min_lr, - verbose=verbose, - threshold_mode="abs", - ) - - # set the number of iterations over which you would like to accummulate gradients - optim_step_count = ( - len(dataloaders["train"].keys()) - if loss_accum_batch_n is None - else loss_accum_batch_n - ) - - if track_training: - tracker_dict = dict( - 
correlation=partial( - get_correlations, - model=model, - dataloaders=dataloaders["validation"], - device=device, - per_neuron=False, - ), - poisson_loss=partial( - get_poisson_loss, - model=model, - datalaoders=dataloaders["validation"], - device=device, - per_neuron=False, - avg=False, - ), - ) - if hasattr(model, "tracked_values"): - tracker_dict.update(model.tracked_values) - tracker = MultipleObjectiveTracker(**tracker_dict) - else: - tracker = None - - # train over epochs - for epoch, val_obj in early_stopping( - model, - stop_closure, - interval=interval, - patience=patience, - start=epoch, - max_iter=max_iter, - maximize=maximize, - tolerance=tolerance, - restore_best=restore_best, - tracker=tracker, - scheduler=scheduler, - lr_decay_steps=lr_decay_steps, - ): - - # print the quantities from tracker - if verbose and tracker is not None: - print("=======================================", flush=True) - for key in tracker.log.keys(): - print(key, tracker.log[key][-1], flush=True) - - # executes callback function if passed in keyword args - if cb is not None: - cb() - - # train over batches - optimizer.zero_grad() - if hasattr(tqdm, "_instances"): - tqdm._instances.clear() - - for batch_no, (data_key, data) in tqdm( - enumerate(LongCycler(dataloaders["train"])), - total=n_iterations, - desc="Epoch {}".format(epoch), - ): - loss = full_objective( - model, dataloaders["train"], data_key, *data, **data._asdict() - ) - loss.backward() - if (batch_no + 1) % optim_step_count == 0: - optimizer.step() - optimizer.zero_grad() - - ##### Model evaluation #################################################################################################### - model.eval() - tracker.finalize() if track_training else None - - # Compute avg validation and test correlation - validation_correlation = get_correlations( - model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False - ) - if return_test_score: - test_correlation = get_correlations( - model, dataloaders["test"], device=device, as_dict=False, per_neuron=False - ) - - # return the whole tracker output as a dict - output = {k: v for k, v in tracker.log.items()} if track_training else {} - output["validation_corr"] = validation_correlation - - score = ( - np.mean(test_correlation) - if return_test_score - else np.mean(validation_correlation) - ) - return score, output, model.state_dict() - - def standard_trainer( model, dataloaders, @@ -286,14 +78,8 @@ def standard_trainer( def full_objective(model, dataloader, data_key, *args, detach_core): - loss_scale = ( - np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) - if scale_loss - else 1.0 - ) - regularizers = int( - not detach_core - ) * model.core.regularizer() + model.readout.regularizer(data_key) + loss_scale = np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) if scale_loss else 1.0 + regularizers = int(not detach_core) * model.core.regularizer() + model.readout.regularizer(data_key) return ( loss_scale * criterion( @@ -332,11 +118,7 @@ def full_objective(model, dataloader, data_key, *args, detach_core): ) # set the number of iterations over which you would like to accummulate gradients - optim_step_count = ( - len(dataloaders["train"].keys()) - if loss_accum_batch_n is None - else loss_accum_batch_n - ) + optim_step_count = len(dataloaders["train"].keys()) if loss_accum_batch_n is None else loss_accum_batch_n if track_training: tracker_dict = dict( @@ -396,9 +178,7 @@ def full_objective(model, dataloader, data_key, *args, detach_core): desc="Epoch 
{}".format(epoch), ): - loss = full_objective( - model, dataloaders["train"], data_key, *data, detach_core=detach_core - ) + loss = full_objective(model, dataloaders["train"], data_key, *data, detach_core=detach_core) loss.backward() if (batch_no + 1) % optim_step_count == 0: optimizer.step() @@ -412,19 +192,13 @@ def full_objective(model, dataloader, data_key, *args, detach_core): validation_correlation = get_correlations( model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False ) - test_correlation = get_correlations( - model, dataloaders["test"], device=device, as_dict=False, per_neuron=False - ) + test_correlation = get_correlations(model, dataloaders["test"], device=device, as_dict=False, per_neuron=False) # return the whole tracker output as a dict output = {k: v for k, v in tracker.log.items()} if track_training else {} output["validation_corr"] = validation_correlation - score = ( - np.mean(test_correlation) - if return_test_score - else np.mean(validation_correlation) - ) + score = np.mean(test_correlation) if return_test_score else np.mean(validation_correlation) return score, output, model.state_dict() From 0c1a39c5b1025f065f92f8e74c092a9007bbd282 Mon Sep 17 00:00:00 2001 From: Max Burg Date: Fri, 18 Mar 2022 14:18:03 +0100 Subject: [PATCH 3/6] [del] debug print --- nndichromacy/training/trainers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nndichromacy/training/trainers.py b/nndichromacy/training/trainers.py index d1f44248..cff73b9b 100644 --- a/nndichromacy/training/trainers.py +++ b/nndichromacy/training/trainers.py @@ -44,7 +44,6 @@ def standard_trainer( detach_core=False, **kwargs ): - print("TRAINER 2") """ Args: From ba30ae0d6cf22d7dd4ac0311f0531df651ea3087 Mon Sep 17 00:00:00 2001 From: Max Burg Date: Fri, 18 Mar 2022 15:09:02 +0100 Subject: [PATCH 4/6] undo --- nndichromacy/training/trainers.py | 238 +++++++++++++++++++++++++++++- 1 file changed, 232 insertions(+), 6 deletions(-) diff --git a/nndichromacy/training/trainers.py b/nndichromacy/training/trainers.py index cff73b9b..c3943da9 100644 --- a/nndichromacy/training/trainers.py +++ b/nndichromacy/training/trainers.py @@ -16,6 +16,214 @@ from ..utility.measures import get_correlations, get_poisson_loss +def standart_trainer( + model, + dataloaders, + seed, + avg_loss=False, + scale_loss=True, # trainer args + loss_function="PoissonLoss", + stop_function="get_correlations", + loss_accum_batch_n=None, + device="cuda", + verbose=True, + interval=1, + patience=5, + epoch=0, + lr_init=0.005, # early stopping args + max_iter=100, + maximize=True, + tolerance=1e-6, + restore_best=True, + lr_decay_steps=3, + lr_decay_factor=0.3, + min_lr=0.0001, # lr scheduler args + cb=None, + track_training=False, + return_test_score=False, + **kwargs +): + """ + + Args: + model: + dataloaders: + seed: + avg_loss: + scale_loss: + loss_function: + stop_function: + loss_accum_batch_n: + device: + verbose: + interval: + patience: + epoch: + lr_init: + max_iter: + maximize: + tolerance: + restore_best: + lr_decay_steps: + lr_decay_factor: + min_lr: + cb: + track_training: + **kwargs: + + Returns: + + """ + + def full_objective(model, dataloader, data_key, *args, **kwargs): + """ + + Args: + model: + dataloader: + data_key: + *args: + + Returns: + + """ + loss_scale = ( + np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) + if scale_loss + else 1.0 + ) + return loss_scale * criterion( + model(*args, data_key=data_key, **kwargs), args[1].to(device) + ) + model.regularizer(data_key) + + ##### Model 
training #################################################################################################### + model.to(device) + set_random_seed(seed) + model.train() + + criterion = getattr(mlmeasures, loss_function)(avg=avg_loss) + stop_closure = partial( + getattr(measures, stop_function), + dataloaders=dataloaders["validation"], + device=device, + per_neuron=False, + avg=True, + ) + + n_iterations = len(LongCycler(dataloaders["train"])) + + optimizer = torch.optim.Adam(model.parameters(), lr=lr_init) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="max" if maximize else "min", + factor=lr_decay_factor, + patience=patience, + threshold=tolerance, + min_lr=min_lr, + verbose=verbose, + threshold_mode="abs", + ) + + # set the number of iterations over which you would like to accummulate gradients + optim_step_count = ( + len(dataloaders["train"].keys()) + if loss_accum_batch_n is None + else loss_accum_batch_n + ) + + if track_training: + tracker_dict = dict( + correlation=partial( + get_correlations, + model=model, + dataloaders=dataloaders["validation"], + device=device, + per_neuron=False, + ), + poisson_loss=partial( + get_poisson_loss, + model=model, + datalaoders=dataloaders["validation"], + device=device, + per_neuron=False, + avg=False, + ), + ) + if hasattr(model, "tracked_values"): + tracker_dict.update(model.tracked_values) + tracker = MultipleObjectiveTracker(**tracker_dict) + else: + tracker = None + + # train over epochs + for epoch, val_obj in early_stopping( + model, + stop_closure, + interval=interval, + patience=patience, + start=epoch, + max_iter=max_iter, + maximize=maximize, + tolerance=tolerance, + restore_best=restore_best, + tracker=tracker, + scheduler=scheduler, + lr_decay_steps=lr_decay_steps, + ): + + # print the quantities from tracker + if verbose and tracker is not None: + print("=======================================", flush=True) + for key in tracker.log.keys(): + print(key, tracker.log[key][-1], flush=True) + + # executes callback function if passed in keyword args + if cb is not None: + cb() + + # train over batches + optimizer.zero_grad() + if hasattr(tqdm, "_instances"): + tqdm._instances.clear() + + for batch_no, (data_key, data) in tqdm( + enumerate(LongCycler(dataloaders["train"])), + total=n_iterations, + desc="Epoch {}".format(epoch), + ): + loss = full_objective( + model, dataloaders["train"], data_key, *data, **data._asdict() + ) + loss.backward() + if (batch_no + 1) % optim_step_count == 0: + optimizer.step() + optimizer.zero_grad() + + ##### Model evaluation #################################################################################################### + model.eval() + tracker.finalize() if track_training else None + + # Compute avg validation and test correlation + validation_correlation = get_correlations( + model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False + ) + if return_test_score: + test_correlation = get_correlations( + model, dataloaders["test"], device=device, as_dict=False, per_neuron=False + ) + + # return the whole tracker output as a dict + output = {k: v for k, v in tracker.log.items()} if track_training else {} + output["validation_corr"] = validation_correlation + + score = ( + np.mean(test_correlation) + if return_test_score + else np.mean(validation_correlation) + ) + return score, output, model.state_dict() + + def standard_trainer( model, dataloaders, @@ -78,8 +286,14 @@ def standard_trainer( def full_objective(model, dataloader, data_key, *args, 
detach_core): - loss_scale = np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) if scale_loss else 1.0 - regularizers = int(not detach_core) * model.core.regularizer() + model.readout.regularizer(data_key) + loss_scale = ( + np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0]) + if scale_loss + else 1.0 + ) + regularizers = int( + not detach_core + ) * model.core.regularizer() + model.readout.regularizer(data_key) return ( loss_scale * criterion( @@ -118,7 +332,11 @@ def full_objective(model, dataloader, data_key, *args, detach_core): ) # set the number of iterations over which you would like to accummulate gradients - optim_step_count = len(dataloaders["train"].keys()) if loss_accum_batch_n is None else loss_accum_batch_n + optim_step_count = ( + len(dataloaders["train"].keys()) + if loss_accum_batch_n is None + else loss_accum_batch_n + ) if track_training: tracker_dict = dict( @@ -178,7 +396,9 @@ def full_objective(model, dataloader, data_key, *args, detach_core): desc="Epoch {}".format(epoch), ): - loss = full_objective(model, dataloaders["train"], data_key, *data, detach_core=detach_core) + loss = full_objective( + model, dataloaders["train"], data_key, *data, detach_core=detach_core + ) loss.backward() if (batch_no + 1) % optim_step_count == 0: optimizer.step() @@ -192,13 +412,19 @@ def full_objective(model, dataloader, data_key, *args, detach_core): validation_correlation = get_correlations( model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False ) - test_correlation = get_correlations(model, dataloaders["test"], device=device, as_dict=False, per_neuron=False) + test_correlation = get_correlations( + model, dataloaders["test"], device=device, as_dict=False, per_neuron=False + ) # return the whole tracker output as a dict output = {k: v for k, v in tracker.log.items()} if track_training else {} output["validation_corr"] = validation_correlation - score = np.mean(test_correlation) if return_test_score else np.mean(validation_correlation) + score = ( + np.mean(test_correlation) + if return_test_score + else np.mean(validation_correlation) + ) return score, output, model.state_dict() From e9028bf8de3dfb6d9050ffeffb545f8eb351552e Mon Sep 17 00:00:00 2001 From: Max Burg Date: Tue, 5 Apr 2022 10:12:02 +0200 Subject: [PATCH 5/6] [add] argument to turn off final_readout_nonlinearity in Encoder and se_core_full_gauss_readout model --- nndichromacy/models/encoders.py | 8 ++- nndichromacy/models/models.py | 99 +++++++++------------------------ 2 files changed, 33 insertions(+), 74 deletions(-) diff --git a/nndichromacy/models/encoders.py b/nndichromacy/models/encoders.py index 67ff78ba..00113ebf 100644 --- a/nndichromacy/models/encoders.py +++ b/nndichromacy/models/encoders.py @@ -7,12 +7,13 @@ class Encoder(nn.Module): - def __init__(self, core, readout, elu_offset, shifter=None): + def __init__(self, core, readout, final_nonlinearity, elu_offset, shifter=None): super().__init__() self.core = core self.readout = readout self.offset = elu_offset self.shifter = shifter + self.nonlinearity = final_nonlinearity def forward( self, *args, data_key=None, eye_pos=None, shift=None, trial_idx=None, **kwargs @@ -52,7 +53,10 @@ def forward( shift = self.shifter[data_key](eye_pos) x = self.readout(x, data_key=data_key, shift=shift, **kwargs) - return F.elu(x + self.offset) + 1 + if self.nonlinearity is True: + return F.elu(x + self.offset) + 1 + else: + return x def regularizer(self, data_key): return self.core.regularizer() + 
self.readout.regularizer(data_key=data_key) diff --git a/nndichromacy/models/models.py b/nndichromacy/models/models.py index a6fa9772..ec2acdf9 100644 --- a/nndichromacy/models/models.py +++ b/nndichromacy/models/models.py @@ -29,9 +29,7 @@ except ModuleNotFoundError: pass except: - print( - "dj database connection could not be established. no access to pretrained models available." - ) + print("dj database connection could not be established. no access to pretrained models available.") # from . import logger as log @@ -162,11 +160,7 @@ def se_core_gauss_readout( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] class Encoder(nn.Module): def __init__(self, core, readout, elu_offset): @@ -251,6 +245,7 @@ def se_core_full_gauss_readout( init_sigma=1.0, readout_bias=True, # readout args, gamma_readout=4, + final_readout_nonlinearity=True, elu_offset=0, stack=None, se_reduction=32, @@ -313,11 +308,7 @@ def se_core_full_gauss_readout( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] source_grids = None grid_mean_predictor_type = None @@ -326,18 +317,13 @@ def se_core_full_gauss_readout( grid_mean_predictor_type = grid_mean_predictor.pop("type") if grid_mean_predictor_type == "cortex": input_dim = grid_mean_predictor.pop("input_dimensions", 2) - source_grids = { - k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim] - for k, v in dataloaders.items() - } + source_grids = {k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim] for k, v in dataloaders.items()} elif grid_mean_predictor_type == "shared": pass shared_match_ids = None if share_features or share_grid: - shared_match_ids = { - k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items() - } + shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()} all_multi_unit_ids = set(np.hstack(shared_match_ids.values())) for match_id in shared_match_ids.values(): @@ -414,7 +400,13 @@ def se_core_full_gauss_readout( _, targets = next(iter(value))[:2] readout[key].bias.data = targets.mean(0) - model = Encoder(core=core, readout=readout, elu_offset=elu_offset, shifter=shifter) + model = Encoder( + core=core, + readout=readout, + final_nonlinearity=final_readout_nonlinearity, + elu_offset=elu_offset, + shifter=shifter, + ) return model @@ -503,11 +495,7 @@ def se_core_behavior_gauss( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] if "train" in dataloaders.keys(): dataloaders = dataloaders["train"] @@ -527,18 +515,13 @@ def se_core_behavior_gauss( grid_mean_predictor_type = grid_mean_predictor.pop("type") if 
grid_mean_predictor_type == "cortex": input_dim = grid_mean_predictor.pop("input_dimensions", 2) - source_grids = { - k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim] - for k, v in dataloaders.items() - } + source_grids = {k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim] for k, v in dataloaders.items()} elif grid_mean_predictor_type == "shared": pass shared_match_ids = None if share_features or share_grid: - shared_match_ids = { - k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items() - } + shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()} all_multi_unit_ids = set(np.hstack(shared_match_ids.values())) for match_id in shared_match_ids.values(): @@ -618,9 +601,7 @@ def se_core_behavior_gauss( _, targets = next(iter(value))[:2] readout[key].bias.data = targets.mean(0) - model = GeneralEncoder( - core=core, readout=readout, elu_offset=elu_offset, shifter=shifter - ) + model = GeneralEncoder(core=core, readout=readout, elu_offset=elu_offset, shifter=shifter) return model @@ -683,11 +664,7 @@ def se_core_point_readout( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] set_random_seed(seed) @@ -792,11 +769,7 @@ def stacked2d_core_gaussian_readout( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] class Encoder(nn.Module): def __init__(self, core, readout, elu_offset): @@ -905,11 +878,7 @@ def vgg_core_gauss_readout( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] class Encoder(nn.Module): """ @@ -1011,11 +980,7 @@ def vgg_core_full_gauss_readout( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] class Encoder(nn.Module): """ @@ -1118,11 +1083,7 @@ def se_core_spatialXfeature_readout( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] class Encoder(nn.Module): def __init__(self, core, readout, elu_offset): @@ -1219,11 +1180,7 @@ def rotation_equivariant_gauss_readout( 
in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = ( - list(input_channels.values())[0] - if isinstance(input_channels, dict) - else input_channels[0] - ) + core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] class Encoder(nn.Module): def __init__(self, core, readout, elu_offset): @@ -1337,9 +1294,9 @@ def augmented_full_readout( model.readout["augmentation"].features.data[ :, :, :, insert_index : insert_index + neuron_repeats ] = features[:, :, :, None] - model.readout["augmentation"].bias.data[ - insert_index : insert_index + neuron_repeats - ] = model.readout[data_key].bias.data[i] + model.readout["augmentation"].bias.data[insert_index : insert_index + neuron_repeats] = model.readout[ + data_key + ].bias.data[i] model.readout["augmentation"].sigma.data[ :, insert_index : insert_index + neuron_repeats, :, : ] = model.readout[data_key].sigma.data[:, i, ...] @@ -1362,9 +1319,7 @@ def augmented_full_readout( if rename_data_key is False: if len(sessions) > 1: - raise ValueError( - "Renaming to original data key is only possible when dataloader has one data key only" - ) + raise ValueError("Renaming to original data key is only possible when dataloader has one data key only") model.readout[sessions[0]] = model.readout.pop("augmentation") return models From f8cf924521497312ee8daded347f139f08769795 Mon Sep 17 00:00:00 2001 From: Max Burg Date: Wed, 6 Apr 2022 10:14:18 +0200 Subject: [PATCH 6/6] [refactor] more descriptive attribute name --- nndichromacy/models/encoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nndichromacy/models/encoders.py b/nndichromacy/models/encoders.py index 00113ebf..4cd38ab7 100644 --- a/nndichromacy/models/encoders.py +++ b/nndichromacy/models/encoders.py @@ -13,7 +13,7 @@ def __init__(self, core, readout, final_nonlinearity, elu_offset, shifter=None): self.readout = readout self.offset = elu_offset self.shifter = shifter - self.nonlinearity = final_nonlinearity + self.readout_nonlinearity = final_nonlinearity def forward( self, *args, data_key=None, eye_pos=None, shift=None, trial_idx=None, **kwargs @@ -53,7 +53,7 @@ def forward( shift = self.shifter[data_key](eye_pos) x = self.readout(x, data_key=data_key, shift=shift, **kwargs) - if self.nonlinearity is True: + if self.readout_nonlinearity is True: return F.elu(x + self.offset) + 1 else: return x
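For reference, the batch loop kept by standard_trainer (patches 1-4) accumulates gradients over optim_step_count batches — one per data key when loss_accum_batch_n is None — before each optimizer step. Below is a minimal, self-contained sketch of that accumulation pattern only, using a toy model and random batches in place of the repository's LongCycler and dataloaders; none of these names are the repository's actual objects.

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
optim_step_count = 4  # stands in for len(dataloaders["train"].keys())

# toy batches standing in for LongCycler(dataloaders["train"])
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(12)]

optimizer.zero_grad()
for batch_no, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients accumulate in .grad across batches
    if (batch_no + 1) % optim_step_count == 0:
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # reset before the next accumulation window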
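The net effect of patches 5 and 6 is that the Encoder's output stage, ELU(x + elu_offset) + 1 (which keeps predicted rates strictly positive for the Poisson loss), can now be switched off by passing final_readout_nonlinearity=False to se_core_full_gauss_readout, in which case the raw readout is returned. The sketch below mirrors only that conditional output stage; apply_final_nonlinearity is a hypothetical helper for illustration, not a function in the repository.

import torch
import torch.nn.functional as F

def apply_final_nonlinearity(x, elu_offset=0.0, readout_nonlinearity=True):
    # Mirrors the tail of Encoder.forward after the change:
    # ELU(x + offset) + 1 keeps outputs positive; with the flag off,
    # the raw readout is passed through unchanged.
    return F.elu(x + elu_offset) + 1 if readout_nonlinearity else x

raw = torch.tensor([-2.0, 0.0, 3.0])
print(apply_final_nonlinearity(raw))                              # strictly positive rates
print(apply_final_nonlinearity(raw, readout_nonlinearity=False))  # raw readout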