Skip to content

Commit

Permalink
Merge pull request #22 from berenslab/circuit-extensions
Browse files Browse the repository at this point in the history
Circuit extensions
  • Loading branch information
alex404 authored Oct 21, 2024
2 parents 6d9d4a9 + 5c13643 commit 165ffed
Show file tree
Hide file tree
Showing 23 changed files with 464 additions and 324 deletions.
2 changes: 1 addition & 1 deletion config/base/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ system:
wandb_preempt: False # Whether to enable Weights & Biases preemption

# Whether to use Weights & Biases for logging
use_wandb: True
use_wandb: False

# Sweep command setup
sweep:
Expand Down
17 changes: 15 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import sys
import warnings
from typing import Dict, List, cast

import hydra
import torch
Expand All @@ -17,7 +18,7 @@
from runner.initialize import initialize
from runner.sweep import launch_sweep
from runner.train import train
from runner.util import delete_results
from runner.util import assemble_neural_circuits, delete_results

# Load the eval resolver for OmegaConf
OmegaConf.register_new_resolver("eval", eval)
Expand All @@ -36,14 +37,26 @@ def _program(cfg: DictConfig):

device = torch.device(cfg.system.device)

brain = Brain(**cfg.brain).to(device)
sensors = OmegaConf.to_container(cfg.brain.sensors, resolve=True)
sensors = cast(Dict[str, List[int]], sensors)

connections = OmegaConf.to_container(cfg.brain.connections, resolve=True)
connections = cast(List[List[str]], connections)

connectome, circuits = assemble_neural_circuits(
cfg.brain.circuits, sensors, connections
)

brain = Brain(circuits, sensors, connectome).to(device)

if hasattr(cfg, "optimizer"):
optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters())
objective = instantiate(cfg.optimizer.objective, brain=brain)
else:
warnings.warn("No optimizer config specified, is that wanted?")

if cfg.command == "scan":
brain.scan()
brain.scan_circuits()
sys.exit(0)

Expand Down
14 changes: 7 additions & 7 deletions resources/config_templates/user/brain/deep-autoencoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ circuits:
num_channels: [16,32] # Two layers with 16 and 32 channels
kernel_size: 8
stride: 2
act_name: ${activation}
activation: ${activation}
layer_names: ["bipolar", "retinal_ganglion"] # Names inspired by retinal cell types

# Thalamus: relay and processing station
Expand All @@ -38,7 +38,7 @@ circuits:
num_channels: 64
kernel_size: 5
stride: 1
act_name: ${activation}
activation: ${activation}
layer_names: ["lgn"] # Lateral Geniculate Nucleus

# Visual Cortex: higher-level visual processing
Expand All @@ -48,7 +48,7 @@ circuits:
num_channels: 64
kernel_size: 8
stride: 2
act_name: ${activation}
activation: ${activation}
layer_names: ["v1"] # Primary Visual Cortex

# Inferotemporal Cortex: Associations
Expand All @@ -58,7 +58,7 @@ circuits:
- 64 # Size of the latent representation
hidden_units:
- 64 # Number of hidden units
act_name: ${activation}
activation: ${activation}

# Prefrontal Cortex: high-level cognitive processing
prefrontal:
Expand All @@ -67,15 +67,15 @@ circuits:
- 32 # Size of the latent representation
hidden_units:
- 32 # Number of hidden units
act_name: ${activation}
activation: ${activation}

# Prefrontal Cortex: high-level cognitive processing
inferotemporal_decoder:
_target_: retinal_rl.models.circuits.fully_connected.FullyConnectedDecoder
output_shape: "inferotemporal.input_shape" # Size of the latent representation
hidden_units:
- 64 # Number of hidden units
act_name: ${activation}
activation: ${activation}

# Decoder: for reconstructing the input from the latent representation
decoder:
Expand All @@ -84,7 +84,7 @@ circuits:
num_channels: [64,32,16,3] # For a symmetric encoder, this should be the reverse of the num_channels in the CNN layers up to the point of decoding (in this case, the thalamus)
kernel_size: [5,8,8,8]
stride: [1,2,2,2]
act_name: ${activation}
activation: ${activation}

# Classifier: for categorizing the input into classes
classifier:
Expand Down
26 changes: 19 additions & 7 deletions resources/config_templates/user/brain/shallow-autoencoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ connections:
- ["thalamus","visual_cortex"] # Thalamus to visual cortex
- ["visual_cortex", "prefrontal"] # Visual cortex to prefrontal cortex
- ["prefrontal", "classifier"] # Prefrontal cortex to classifier
- ["visual_cortex", "decoder"] # Thalamus to decoder (for reconstruction)
- ["visual_cortex", "v1_decoder"] # v1 to decoder (for reconstruction)

# Define the individual nodes (neural circuits) of the network. Many circuit
# parameters are interpolated from the experiment config.
Expand All @@ -28,8 +28,10 @@ circuits:
- ${bp_kernel_size}
- ${rgc_kernel_size}
stride: 2
act_name: ${activation}
activation: ${activation}
layer_names: ["bipolar", "retinal_ganglion"] # Names inspired by retinal cell types
layer_norm: ${layer_norm}
affine_norm: ${affine_norm}

# Thalamus: relay and processing station
thalamus:
Expand All @@ -38,8 +40,10 @@ circuits:
num_channels: 64
kernel_size: ${lgn_kernel_size}
stride: 1
act_name: ${activation}
activation: ${activation}
layer_names: ["lgn"] # Lateral Geniculate Nucleus
layer_norm: ${layer_norm}
affine_norm: ${affine_norm}

# Visual Cortex: higher-level visual processing
visual_cortex:
Expand All @@ -48,8 +52,10 @@ circuits:
num_channels: 64
kernel_size: ${v1_kernel_size}
stride: 2
act_name: ${activation}
activation: ${activation}
layer_names: ["v1"] # Primary Visual Cortex
layer_norm: ${layer_norm}
affine_norm: ${affine_norm}

# Prefrontal Cortex: high-level cognitive processing
prefrontal:
Expand All @@ -58,20 +64,26 @@ circuits:
- 128 # Size of the latent representation
hidden_units:
- 64 # Number of hidden units
act_name: ${activation}
activation: ${activation}

# Decoder: for reconstructing the input from the latent representation
decoder:
v1_decoder:
_target_: retinal_rl.models.circuits.convolutional.ConvolutionalDecoder
num_layers: 4
layer_norm: ${layer_norm}
affine_norm: ${affine_norm}
num_channels: [64,32,16,3] # For a symmetric encoder, this should be the reverse of the num_channels in the CNN layers up to the point of decoding (in this case, the thalamus)
kernel_size:
- ${v1_kernel_size}
- ${lgn_kernel_size}
- ${rgc_kernel_size}
- ${bp_kernel_size}
stride: [2,1,2,2]
act_name: ${activation}
activation:
- ${activation}
- ${activation}
- ${activation}
- "tanh"

# Classifier: for categorizing the input into classes
classifier:
Expand Down
8 changes: 4 additions & 4 deletions resources/config_templates/user/dataset/cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ imageset:
- ${eval:"1.5 if ${shot_noise_transform} else 0"}
- _target_: retinal_rl.datasets.transforms.ContrastTransform
contrast_range:
- ${eval:"0.01 if ${contrast_noise_transform} else 1"}
- ${eval:"1.2 if ${contrast_noise_transform} else 1"}
- ${eval:"0.6 if ${contrast_noise_transform} else 1"}
- ${eval:"1.4 if ${contrast_noise_transform} else 1"}
- _target_: retinal_rl.datasets.transforms.IlluminationTransform
brightness_range:
- ${eval:"0.1 if ${brightness_noise_transform} else 1"}
- ${eval:"1.5 if ${brightness_noise_transform} else 1"}
- ${eval:"0.6 if ${brightness_noise_transform} else 1"}
- ${eval:"1.4 if ${brightness_noise_transform} else 1"}
- _target_: retinal_rl.datasets.transforms.BlurTransform
blur_range:
- ${eval:"0 if ${blur_noise_transform} else 0"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ framework: classification
# This is a free list of parameters that can be interpolated by the subconfigs
# in sweep, dataset, brain, and optimizer. A major use for this is interpolating
# values in the subconfigs, and then looping over them in a sweep.
weight_decay: 0.00001

activation: "elu"

bp_kernel_size: 14
Expand All @@ -29,3 +31,6 @@ shot_noise_transform: True
contrast_noise_transform: True
brightness_noise_transform: True
blur_noise_transform: True

layer_norm: False
affine_norm: False
4 changes: 3 additions & 1 deletion resources/config_templates/user/optimizer/class-recon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ num_epochs: 100
optimizer: # torch.optim Class and parameters
_target_: torch.optim.Adam
lr: 0.0003
weight_decay: ${weight_decay}

# The objective function
objective:
Expand All @@ -27,11 +28,12 @@ objective:
- 1
- 1
- _target_: retinal_rl.models.loss.ReconstructionLoss
target_decoder: "v1_decoder"
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- retina
- thalamus
- visual_cortex
- decoder
- v1_decoder
- inferotemporal_decoder
weights:
- ${recon_weight_retina}
Expand Down
12 changes: 9 additions & 3 deletions resources/config_templates/user/sweep/recon-weight-sweep.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ method: grid
project: retinal-rl

parameters:
use_wandb:
value: True
recon_weight_retina:
values: [0,0.9,0.99,0.999,0.9999,1]
values: [0,0.9,0.999,1]
recon_weight_thalamus:
values: [0,0.9,0.99,0.999,0.9999,1]
values: [0,0.9,0.999,1]
recon_weight_cortex:
values: [0,0.9,0.999,1]
optimizer:
value: "recon-weight"
value: "class-recon"
activation:
value: "gelu"
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ parameters:
value: "recon-weight"
brain:
value: "shallow-autoencoder"
use_wandb:
value: True
shot_noise_transform:
values: [False,True]
contrast_noise_transform:
Expand Down
2 changes: 1 addition & 1 deletion resources/retinal-rl.def
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ From: ubuntu:22.04
seaborn==0.13.2 \
hydra-core==1.3.2 \
networkx==3.3 \
ruff==0.5.4
ruff==0.7.0

# Clean up for smaller container size
rm acc159linux-x64.zip
Expand Down
50 changes: 0 additions & 50 deletions retinal_rl/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,56 +227,6 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F
plt.axis("off")

return fig
# Draw labels
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")

# Add a legend for optimizers
legend_elements = [
Line2D(
[0],
[0],
marker="o",
color="w",
label=f"Optimizer: {name}",
markerfacecolor="none",
markeredgecolor=color,
markersize=15,
markeredgewidth=3,
)
for name, color in zip(objective.losses.keys(), optimizer_colors)
]

# Add legend elements for sensor and circuit
legend_elements.extend(
[
Line2D(
[0],
[0],
marker="o",
color="w",
label="Sensor",
markerfacecolor=color_map["sensor"],
markersize=15,
),
Line2D(
[0],
[0],
marker="o",
color="w",
label="Circuit",
markerfacecolor=color_map["circuit"],
markersize=15,
),
]
)

plt.legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1, 0.5))

plt.title("Brain Connectome and Optimizer Targets")
plt.tight_layout()
plt.axis("off")

return fig


def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure:
Expand Down
Loading

0 comments on commit 165ffed

Please sign in to comment.