-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from berenslab/flatten-objectives
Flatten objectives
- Loading branch information
Showing
11 changed files
with
322 additions
and
386 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 37 additions & 52 deletions
89
resources/config_templates/user/optimizer/class-recon.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,41 @@ | ||
# Number of training epochs | ||
num_epochs: 100 | ||
|
||
# The optimizer to use | ||
optimizer: # torch.optim Class and parameters | ||
_target_: torch.optim.Adam | ||
lr: 0.0003 | ||
|
||
goal: | ||
recon: | ||
min_epoch: 0 # Epoch to start optimizer | ||
max_epoch: 100 # Epoch to stop optimizer | ||
losses: # Weighted optimizer losses as defined in retinal-rl | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: ${recon_weight_retina} | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: ${eval:'1-${recon_weight_retina}'} | ||
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction | ||
- retina | ||
decode: | ||
min_epoch: 0 # Epoch to start optimizer | ||
max_epoch: 100 # Epoch to stop optimizer | ||
losses: # Weighted optimizer losses as defined in retinal-rl | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: 1 | ||
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction | ||
- decoder | ||
- inferotemporal_decoder | ||
mixed: | ||
min_epoch: 0 | ||
max_epoch: 100 | ||
losses: | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: ${recon_weight_thalamus} | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: ${eval:'1-${recon_weight_thalamus}'} | ||
target_circuits: # The thalamus is somewhat sensitive to task losses | ||
- thalamus | ||
cortex: | ||
min_epoch: 0 | ||
max_epoch: 100 | ||
losses: | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: ${recon_weight_cortex} | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: ${eval:'1-${recon_weight_cortex}'} | ||
target_circuits: # Visual cortex and downstream layers are driven by the task | ||
- visual_cortex | ||
- inferotemporal | ||
class: | ||
min_epoch: 0 | ||
max_epoch: 100 | ||
losses: | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: 1 | ||
- _target_: retinal_rl.classification.loss.PercentCorrect | ||
weight: 0 | ||
target_circuits: # Visual cortex and downstream layers are driven by the task | ||
- prefrontal | ||
- classifier | ||
# The objective function | ||
objective: | ||
_target_: retinal_rl.models.objective.Objective | ||
losses: | ||
- _target_: retinal_rl.classification.loss.PercentCorrect | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction | ||
- retina | ||
- thalamus | ||
- visual_cortex | ||
- inferotemporal | ||
- prefrontal | ||
- classifier | ||
weights: | ||
- ${eval:'1-${recon_weight_retina}'} | ||
- ${eval:'1-${recon_weight_thalamus}'} | ||
- ${eval:'1-${recon_weight_cortex}'} | ||
- 1 | ||
- 1 | ||
- 1 | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction | ||
- retina | ||
- thalamus | ||
- visual_cortex | ||
- decoder | ||
- inferotemporal_decoder | ||
weights: | ||
- ${recon_weight_retina} | ||
- ${recon_weight_thalamus} | ||
- ${recon_weight_cortex} | ||
- 1 | ||
- 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.