Skip to content

Commit

Permalink
Merge pull request #16 from berenslab/flatten-objectives
Browse files Browse the repository at this point in the history
Flatten objectives
  • Loading branch information
alex404 authored Oct 18, 2024
2 parents baef341 + 2c44b6a commit 5ea0208
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 386 deletions.
8 changes: 3 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from retinal_rl.classification.loss import ClassificationContext
from retinal_rl.framework_interface import TrainingFramework
from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import Goal
from retinal_rl.rl.sample_factory.sf_framework import SFFramework
from runner.analyze import analyze
from runner.dataset import get_datasets
Expand Down Expand Up @@ -40,8 +38,8 @@ def _program(cfg: DictConfig):

brain = Brain(**cfg.brain).to(device)
if hasattr(cfg, "optimizer"):
goal = Goal[ClassificationContext](brain, dict(cfg.optimizer.goal))
optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters())
objective = instantiate(cfg.optimizer.objective, brain=brain)
else:
warnings.warn("No optimizer config specified, is that wanted?")

Expand All @@ -68,7 +66,7 @@ def _program(cfg: DictConfig):
cfg,
device,
brain,
goal,
objective,
optimizer,
train_set,
test_set,
Expand All @@ -82,7 +80,7 @@ def _program(cfg: DictConfig):
cfg,
device,
brain,
goal,
objective,
histories,
train_set,
test_set,
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ select = [
"I", # Import conventions
]

ignore = ["E501"] # Example: Ignore line length warnings
ignore = [
"E501", # Example: Ignore line length warnings
"D", # Ignore all docstring-related warnings
]

[tool.ruff.format]
docstring-code-format = true
Expand Down
89 changes: 37 additions & 52 deletions resources/config_templates/user/optimizer/class-recon.yaml
Original file line number Diff line number Diff line change
@@ -1,56 +1,41 @@
# Number of training epochs
num_epochs: 100

# The optimizer to use
optimizer: # torch.optim Class and parameters
_target_: torch.optim.Adam
lr: 0.0003

goal:
recon:
min_epoch: 0 # Epoch to start optimizer
max_epoch: 100 # Epoch to stop optimizer
losses: # Weighted optimizer losses as defined in retinal-rl
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: ${recon_weight_retina}
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: ${eval:'1-${recon_weight_retina}'}
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- retina
decode:
min_epoch: 0 # Epoch to start optimizer
max_epoch: 100 # Epoch to stop optimizer
losses: # Weighted optimizer losses as defined in retinal-rl
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: 1
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- decoder
- inferotemporal_decoder
mixed:
min_epoch: 0
max_epoch: 100
losses:
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: ${recon_weight_thalamus}
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: ${eval:'1-${recon_weight_thalamus}'}
target_circuits: # The thalamus is somewhat sensitive to task losses
- thalamus
cortex:
min_epoch: 0
max_epoch: 100
losses:
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: ${recon_weight_cortex}
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: ${eval:'1-${recon_weight_cortex}'}
target_circuits: # Visual cortex and downstream layers are driven by the task
- visual_cortex
- inferotemporal
class:
min_epoch: 0
max_epoch: 100
losses:
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: 1
- _target_: retinal_rl.classification.loss.PercentCorrect
weight: 0
target_circuits: # Visual cortex and downstream layers are driven by the task
- prefrontal
- classifier
# The objective function
objective:
_target_: retinal_rl.models.objective.Objective
losses:
- _target_: retinal_rl.classification.loss.PercentCorrect
- _target_: retinal_rl.classification.loss.ClassificationLoss
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- retina
- thalamus
- visual_cortex
- inferotemporal
- prefrontal
- classifier
weights:
- ${eval:'1-${recon_weight_retina}'}
- ${eval:'1-${recon_weight_thalamus}'}
- ${eval:'1-${recon_weight_cortex}'}
- 1
- 1
- 1
- _target_: retinal_rl.models.loss.ReconstructionLoss
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- retina
- thalamus
- visual_cortex
- decoder
- inferotemporal_decoder
weights:
- ${recon_weight_retina}
- ${recon_weight_thalamus}
- ${recon_weight_cortex}
- 1
- 1
163 changes: 91 additions & 72 deletions retinal_rl/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.patches import Circle, Wedge
from matplotlib.ticker import MaxNLocator
from torch import Tensor
from torchvision.utils import make_grid

from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import ContextT, Goal
from retinal_rl.models.objective import ContextT, Objective
from retinal_rl.util import FloatArray


Expand Down Expand Up @@ -107,15 +108,7 @@ def plot_transforms(
return fig


def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
"""Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors.
Args:
----
- brain: The Brain instance
- brain_optimizer: The BrainOptimizer instance
"""
def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure:
graph = brain.connectome

# Compute the depth of each node
Expand All @@ -138,44 +131,102 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
pos[node] = ((i - width / 2) / (width + 1), -(max_depth - depth) / max_depth)

# Set up the plot
fig = plt.figure(figsize=(12, 10))
fig, ax = plt.subplots(figsize=(12, 10))

# Draw edges
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True)
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, ax=ax)

# Color scheme for different node types
color_map = {"sensor": "lightblue", "circuit": "lightgreen"}

# Generate colors for optimizers
optimizer_colors = sns.color_palette("husl", len(goal.losses))
# Generate colors for losses
loss_colors = sns.color_palette("husl", len(objective.losses))

# Prepare node colors and edge colors
node_colors: List[str] = []
edge_colors: List[Tuple[float, float, float]] = []
# Draw nodes
for node in graph.nodes():
x, y = pos[node]

# Determine node type and base color
if node in brain.sensors:
node_colors.append(color_map["sensor"])
base_color = color_map["sensor"]
else:
node_colors.append(color_map["circuit"])

# Determine if the node is targeted by an optimizer
edge_color = "none"
for i, optimizer_name in enumerate(goal.losses.keys()):
if node in goal.target_circuits[optimizer_name]:
edge_color = optimizer_colors[i]
break
edge_colors.append(edge_color)

# Draw nodes with a single call
nx.draw_networkx_nodes(
graph,
pos,
node_color=node_colors,
edgecolors=edge_colors,
node_size=4000,
linewidths=5,
base_color = color_map["circuit"]

# Draw base circle
circle = Circle((x, y), 0.05, facecolor=base_color, edgecolor="black")
ax.add_patch(circle)

# Determine which losses target this node
targeting_losses = [
loss for loss in objective.losses if node in loss.target_circuits
]

if targeting_losses:
# Calculate angle for each loss
angle_per_loss = 360 / len(targeting_losses)

# Draw a wedge for each targeting loss
for i, loss in enumerate(targeting_losses):
start_angle = i * angle_per_loss
wedge = Wedge(
(x, y),
0.07,
start_angle,
start_angle + angle_per_loss,
width=0.02,
facecolor=loss_colors[objective.losses.index(loss)],
)
ax.add_patch(wedge)

# Draw labels
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold", ax=ax)

# Add a legend for losses
legend_elements = [
Line2D(
[0],
[0],
marker="o",
color="w",
label=f"Loss: {loss.__class__.__name__}",
markerfacecolor=color,
markersize=15,
)
for loss, color in zip(objective.losses, loss_colors)
]

# Add legend elements for sensor and circuit
legend_elements.extend(
[
Line2D(
[0],
[0],
marker="o",
color="w",
label="Sensor",
markerfacecolor=color_map["sensor"],
markersize=15,
),
Line2D(
[0],
[0],
marker="o",
color="w",
label="Circuit",
markerfacecolor=color_map["circuit"],
markersize=15,
),
]
)

plt.legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1, 0.5))

plt.title("Brain Connectome and Loss Targets")
plt.tight_layout()
plt.axis("equal")
plt.axis("off")

return fig
# Draw labels
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")

Expand All @@ -192,7 +243,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
markersize=15,
markeredgewidth=3,
)
for name, color in zip(goal.losses.keys(), optimizer_colors)
for name, color in zip(objective.losses.keys(), optimizer_colors)
]

# Add legend elements for sensor and circuit
Expand Down Expand Up @@ -229,13 +280,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:


def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure:
"""Plot the receptive field sizes for each layer of the convolutional part of the network.
Args:
----
- results: Dictionary containing the results from cnn_statistics function
"""
"""Plot the receptive field sizes for each layer of the convolutional part of the network."""
# Get visual field size from the input shape
input_shape = results["input"]["shape"]
[_, height, width] = list(input_shape)
Expand Down Expand Up @@ -300,17 +345,7 @@ def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Fig


def plot_histories(histories: Dict[str, List[float]]) -> Figure:
"""Plot training and test losses over epochs.
Args:
----
histories (Dict[str, List[float]]): Dictionary containing training and test loss histories.
Returns:
-------
Figure: Matplotlib figure containing the plotted histories.
"""
"""Plot training and test losses over epochs."""
train_metrics = [
key.split("_", 1)[1] for key in histories.keys() if key.startswith("train_")
]
Expand Down Expand Up @@ -467,23 +502,7 @@ def plot_reconstructions(
test_estimates: List[Tuple[Tensor, int]],
num_samples: int,
) -> Figure:
"""Plot original and reconstructed images for both training and test sets, including the classes.
Args:
----
train_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
train_inputs (List[Tuple[Tensor, int]]): List of original training images and their classes.
train_estimates (List[Tuple[Tensor, int]]): List of reconstructed training images and their predicted classes.
test_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
test_inputs (List[Tuple[Tensor, int]]): List of original test images and their classes.
test_estimates (List[Tuple[Tensor, int]]): List of reconstructed test images and their predicted classes.
num_samples (int): The number of samples to plot.
Returns:
-------
Figure: The matplotlib Figure object with the plotted images.
"""
"""Plot original and reconstructed images for both training and test sets, including the classes."""
fig, axes = plt.subplots(6, num_samples, figsize=(15, 10))

for i in range(num_samples):
Expand Down
Loading

0 comments on commit 5ea0208

Please sign in to comment.