From 5721a8bef4d82a684b0dfddb63a29f9ecdc01317 Mon Sep 17 00:00:00 2001 From: Sacha Sokoloski Date: Wed, 16 Oct 2024 17:54:45 +0200 Subject: [PATCH 1/6] Churned a bunch of code for an alpha version. Scan runs --- main.py | 8 +- .../user/optimizer/class-recon.yaml | 70 +++--- retinal_rl/analysis/plot.py | 12 +- retinal_rl/classification/loss.py | 23 +- retinal_rl/classification/training.py | 20 +- retinal_rl/models/goal.py | 210 ------------------ retinal_rl/models/loss.py | 73 ++++-- runner/analyze.py | 8 +- runner/train.py | 32 ++- 9 files changed, 147 insertions(+), 309 deletions(-) delete mode 100644 retinal_rl/models/goal.py diff --git a/main.py b/main.py index 532f154d..8547ba93 100644 --- a/main.py +++ b/main.py @@ -9,10 +9,8 @@ from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from retinal_rl.classification.loss import ClassificationContext from retinal_rl.framework_interface import TrainingFramework from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import Goal from retinal_rl.rl.sample_factory.sf_framework import SFFramework from runner.analyze import analyze from runner.dataset import get_datasets @@ -40,7 +38,7 @@ def _program(cfg: DictConfig): brain = Brain(**cfg.brain).to(device) if hasattr(cfg, "optimizer"): - goal = Goal[ClassificationContext](brain, dict(cfg.optimizer.goal)) + objective = instantiate(cfg.optimizer.losses, brain=brain) optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters()) else: warnings.warn("No optimizer config specified, is that wanted?") @@ -68,7 +66,7 @@ def _program(cfg: DictConfig): cfg, device, brain, - goal, + objective, optimizer, train_set, test_set, @@ -82,7 +80,7 @@ def _program(cfg: DictConfig): cfg, device, brain, - goal, + objective, histories, train_set, test_set, diff --git a/resources/config_templates/user/optimizer/class-recon.yaml b/resources/config_templates/user/optimizer/class-recon.yaml index 8c518a50..1a0b99d9 100644 --- a/resources/config_templates/user/optimizer/class-recon.yaml +++ b/resources/config_templates/user/optimizer/class-recon.yaml @@ -2,55 +2,35 @@ optimizer: # torch.optim Class and parameters _target_: torch.optim.Adam lr: 0.0003 -goal: - recon: - min_epoch: 0 # Epoch to start optimizer - max_epoch: 100 # Epoch to stop optimizer - losses: # Weighted optimizer losses as defined in retinal-rl - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: ${recon_weight_retina} - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: ${eval:'1-${recon_weight_retina}'} +losses: + - _target_: retinal_rl.classification.loss.PercentCorrect + - _target_: retinal_rl.classification.loss.ClassificationLoss target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction - retina - decode: - min_epoch: 0 # Epoch to start optimizer - max_epoch: 100 # Epoch to stop optimizer - losses: # Weighted optimizer losses as defined in retinal-rl - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: 1 - target_circuits: # Circuit parameters to optimize with this optimizer. 
We train the retina and the decoder exclusively to maximize reconstruction - - decoder - - inferotemporal_decoder - mixed: - min_epoch: 0 - max_epoch: 100 - losses: - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: ${recon_weight_thalamus} - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: ${eval:'1-${recon_weight_thalamus}'} - target_circuits: # The thalamus is somewhat sensitive to task losses - thalamus - cortex: - min_epoch: 0 - max_epoch: 100 - losses: - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: ${recon_weight_cortex} - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: ${eval:'1-${recon_weight_cortex}'} - target_circuits: # Visual cortex and downstream layers are driven by the task - visual_cortex - inferotemporal - class: - min_epoch: 0 - max_epoch: 100 - losses: - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: 1 - - _target_: retinal_rl.classification.loss.PercentCorrect - weight: 0 - target_circuits: # Visual cortex and downstream layers are driven by the task - prefrontal - classifier + weights: + - ${eval:'1-${recon_weight_retina}'} + - ${eval:'1-${recon_weight_thalamus}'} + - ${eval:'1-${recon_weight_cortex}'} + - 1 + - 1 + - 1 + - _target_: retinal_rl.models.loss.ReconstructionLoss + min_epoch: 0 # Epoch to start optimizer + max_epoch: 100 # Epoch to stop optimizer + target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction + - retina + - thalamus + - visual_cortex + - decoder + - inferotemporal_decoder + weights: + - ${recon_weight_retina} + - ${recon_weight_thalamus} + - ${recon_weight_cortex} + - 1 + - 1 diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index 1897e12e..413d64ee 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -18,7 +18,7 @@ from torchvision.utils import make_grid from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import ContextT, Goal +from retinal_rl.models.objective import ContextT, Objective from retinal_rl.util import FloatArray @@ -107,7 +107,7 @@ def plot_transforms( return fig -def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: +def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure: """Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors. 
Args: @@ -147,7 +147,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: color_map = {"sensor": "lightblue", "circuit": "lightgreen"} # Generate colors for optimizers - optimizer_colors = sns.color_palette("husl", len(goal.losses)) + optimizer_colors = sns.color_palette("husl", len(objective.losses)) # Prepare node colors and edge colors node_colors: List[str] = [] @@ -160,8 +160,8 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: # Determine if the node is targeted by an optimizer edge_color = "none" - for i, optimizer_name in enumerate(goal.losses.keys()): - if node in goal.target_circuits[optimizer_name]: + for i, optimizer_name in enumerate(objective.losses.keys()): + if node in objective.target_circuits[optimizer_name]: edge_color = optimizer_colors[i] break edge_colors.append(edge_color) @@ -192,7 +192,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: markersize=15, markeredgewidth=3, ) - for name, color in zip(goal.losses.keys(), optimizer_colors) + for name, color in zip(objective.losses.keys(), optimizer_colors) ] # Add legend elements for sensor and circuit diff --git a/retinal_rl/classification/loss.py b/retinal_rl/classification/loss.py index 0061b3a7..8fcbca31 100644 --- a/retinal_rl/classification/loss.py +++ b/retinal_rl/classification/loss.py @@ -1,6 +1,6 @@ """Objectives for training models.""" -from typing import Dict, Tuple +from typing import Dict, List, Tuple import torch import torch.nn as nn @@ -38,9 +38,15 @@ def __init__( class ClassificationLoss(Loss[ClassificationContext]): """Loss for computing the cross entropy loss.""" - def __init__(self, weight: float = 1.0): + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): """Initialize the classification loss.""" - super().__init__(weight) + super().__init__(min_epoch, max_epoch, target_circuits, weights) self.loss_fn = nn.CrossEntropyLoss() def compute_value(self, context: ClassificationContext) -> Tensor: @@ -59,9 +65,14 @@ def compute_value(self, context: ClassificationContext) -> Tensor: class PercentCorrect(Loss[ClassificationContext]): """(Inverse) Loss for computing the percent correct classification.""" - def __init__(self, weight: float = 1.0): - """Initialize the percent correct classification loss.""" - super().__init__(weight) + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): + super().__init__(min_epoch, max_epoch, target_circuits, weights) def compute_value(self, context: ClassificationContext) -> Tensor: """Compute the percent correct classification.""" diff --git a/retinal_rl/classification/training.py b/retinal_rl/classification/training.py index 5fa5849d..ce1bf104 100644 --- a/retinal_rl/classification/training.py +++ b/retinal_rl/classification/training.py @@ -18,7 +18,7 @@ get_classification_context, ) from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import Goal +from retinal_rl.models.objective import Objective logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ def run_epoch( device: torch.device, brain: Brain, - goal: Goal[ClassificationContext], + objective: Objective[ClassificationContext], optimizer: Optimizer, history: Dict[str, List[float]], epoch: int, @@ -43,7 +43,7 @@ def run_epoch( ---- device (torch.device): The device to run the computations on. brain (Brain): The Brain model to train and evaluate. 
- goal (Goal): The goal object specifying the training objectives. + objective (Objective): The objective object specifying the training objectives. optimizer (Optimizer): The optimizer for updating the model parameters. history (Dict[str, List[float]]): A dictionary to store the training history. epoch (int): The current epoch number. @@ -56,10 +56,10 @@ def run_epoch( """ train_losses = process_dataset( - device, brain, goal, optimizer, epoch, trainloader, is_training=True + device, brain, objective, optimizer, epoch, trainloader, is_training=True ) test_losses = process_dataset( - device, brain, goal, optimizer, epoch, testloader, is_training=False + device, brain, objective, optimizer, epoch, testloader, is_training=False ) # Update history @@ -76,7 +76,7 @@ def run_epoch( def process_dataset( device: torch.device, brain: Brain, - goal: Goal[ClassificationContext], + objective: Objective[ClassificationContext], optimizer: Optimizer, epoch: int, dataloader: DataLoader[Tuple[Tensor, Tensor, int]], @@ -109,20 +109,20 @@ def process_dataset( if is_training: brain.train() - losses, obj_dict = goal.backward(context) + losses = objective.backward(context) optimizer.step() optimizer.zero_grad(set_to_none=True) else: with torch.no_grad(): brain.eval() - losses, obj_dict = goal.evaluate_objectives(context) + losses: Dict[str, float] = {} + for loss in objective.losses: + losses[loss.key_name] = loss(context).item() # Accumulate losses and objectives for key, value in losses.items(): total_losses[key] = total_losses.get(key, 0.0) + value - for key, value in obj_dict.items(): - total_losses[key] = total_losses.get(key, 0.0) + value steps += 1 diff --git a/retinal_rl/models/goal.py b/retinal_rl/models/goal.py deleted file mode 100644 index c90d0f50..00000000 --- a/retinal_rl/models/goal.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Module for managing optimization of complex neural network models with multiple circuits.""" - -import logging -from typing import Dict, Generic, List, Tuple - -import torch -from hydra.utils import instantiate -from omegaconf import DictConfig -from torch.nn.parameter import Parameter - -from retinal_rl.models.brain import Brain -from retinal_rl.models.loss import ContextT, Loss - -logger = logging.getLogger(__name__) - - -class Goal(Generic[ContextT]): - """Manages multiple optimizers that target NeuralCircuits in a Brain. - - This class handles the initialization, state management, and optimization steps - for multiple optimizers, each associated with specific circuits and objectives. - - - Attributes - ---------- - brain (Brain): The neural network model being optimized. - losses (OrderedDict[str, Optimizer]): Instantiated optimizers, sorted based on connectome. - objectives (Dict[str, List[WeightedLoss]]): Losses for each optimizer. - target_circuits (Dict[str, List[str]]): Target circuits for each optimizer. - min_epochs (Dict[str, int]): Minimum epochs for each optimizer. - max_epochs (Dict[str, int]): Maximum epochs for each optimizer. - - """ - - def __init__(self, brain: Brain, objective_configs: Dict[str, DictConfig]): - """Initialize the BrainOptimizer. - - Args: - ---- - brain (Brain): The neural network model to optimize. - optimizer (Optimizer): The optimizer to use for training. - objective_configs (Dict[str, DictConfig]): Configuration for each optimizer. - Each config should specify target_circuits, optimizer settings, and objectives. - - Raises: - ------ - ValueError: If a specified circuit is not found in the brain. 
- - """ - self.device = next(brain.parameters()).device - self.losses: Dict[str, List[Loss[ContextT]]] = {} - self.target_circuits: Dict[str, List[str]] = {} - self.min_epochs: Dict[str, int] = {} - self.max_epochs: Dict[str, int] = {} - self.params: Dict[str, List[Parameter]] = {} - - for objective, config in objective_configs.items(): - # Collect parameters from target circuits - params = [] - self.min_epochs[objective] = config.get("min_epoch", 0) - self.max_epochs[objective] = config.get("max_epoch", -1) - self.target_circuits[objective] = config.target_circuits - if not set(config.target_circuits).issubset(brain.connectome.nodes): - logger.warning( - f"Some target circuits for objective: {objective} are not in the brain's connectome" - ) - for circuit_name in config.target_circuits: - if circuit_name in brain.circuits: - params.extend(brain.circuits[circuit_name].parameters()) - - self.params[objective] = params - - # Initialize objectives - self.losses[objective] = [ - instantiate(obj_config) for obj_config in config.losses - ] - logger.info( - f"Initialized objective: {objective}, with losses: {[obj.key_name for obj in self.losses[objective]]}, and target circuits: {[circuit_name for circuit_name in self.target_circuits[objective]]}" - ) - - def evaluate_objective( - self, objective: str, context: ContextT - ) -> Tuple[torch.Tensor, Dict[str, float]]: - """Compute the total loss for a specific objective. - - Args: - ---- - objective (str): Name of the objective. - context (ContextT]): Context information for computing objectives. - - Returns: - ------- - Tuple[torch.Tensor, Dict[str, float]]: A tuple containing the total loss - and a dictionary of raw loss values for each objective. - - """ - total_loss = torch.tensor(0.0, device=self.device) - loss_dict: Dict[str, float] = {} - for loss in self.losses[objective]: - weighted_loss, raw_loss = loss(context) - total_loss += weighted_loss - loss_dict[loss.key_name] = raw_loss.item() - return total_loss, loss_dict - - def evaluate_objectives( - self, context: ContextT - ) -> Tuple[Dict[str, float], Dict[str, float]]: - """Compute all objectives without computing gradients. - - This method is useful for evaluation purposes. - - Args: - ---- - context (Dict[str, Any]): Context information for computing objectives. - - Returns: - ------- - Tuple[Dict[str, float], Dict[str, float]]: A tuple containing dictionaries - of total objectives and raw loss values for each objective. - - """ - objectives: Dict[str, float] = {} - loss_dict: Dict[str, float] = {} - for objective in self.losses.keys(): - loss, sub_obj_dict = self.evaluate_objective(objective, context) - objectives[f"{objective}_objective"] = loss.item() - loss_dict.update(sub_obj_dict) - return objectives, loss_dict - - def _is_training_epoch(self, name: str, epoch: int) -> bool: - """Check if the objective should currently be pursued. - - Args: - ---- - name (str): Name of the optimizer. - epoch (int): Current epoch number. - - Returns: - ------- - bool: True if the objective should continue training, False otherwise. - - """ - if epoch < self.min_epochs[name]: - return False - if self.max_epochs[name] < 0: - return True - return epoch < self.max_epochs[name] - - def backward(self, context: ContextT) -> Tuple[Dict[str, float], Dict[str, float]]: - """Compute a backward pass over the brain with respect to all objectives. - - This method computes losses, performs backpropagation, and updates parameters - for all NeuralCircuits. 
- - Args: - ---- - context (ContextT): Context information for computing objectives. - - Returns: - ------- - Tuple[Dict[str, float], Dict[str, float]]: A tuple containing dictionaries - of total losses and raw loss values for each objective. - - """ - objectives: Dict[str, float] = {} - loss_dict: Dict[str, float] = {} - - retain_graph = True - - for i, objective in enumerate(self.losses.keys()): - # Compute losses - loss, sub_loss_dict = self.evaluate_objective(objective, context) - objectives[f"{objective}_objective"] = loss.item() - loss_dict.update(sub_loss_dict) - - # Skip training if the optimizer is not at a training epoch - if not self._is_training_epoch(objective, context.epoch): - continue - - # Set retain_graph to True for all but the last optimizer - retain_graph = i < len(self.losses) - 1 - - # Get parameters for this optimizer - params = self.params[objective] - - # Compute gradients - grads = torch.autograd.grad( - loss, params, create_graph=False, retain_graph=retain_graph - ) - - # Manually update parameters - with torch.no_grad(): - for param, grad in zip(params, grads): - if param.grad is None: - param.grad = grad - else: - param.grad += grad - - # Perform optimization step - return objectives, loss_dict - - def num_epochs(self) -> int: - """Get the maximum number of epochs over all optimizers. - - Returns - ------- - int: The maximum number of epochs across all optimizers. - - """ - return max(self.max_epochs.values()) diff --git a/retinal_rl/models/loss.py b/retinal_rl/models/loss.py index 4417d0ec..6a0e7073 100644 --- a/retinal_rl/models/loss.py +++ b/retinal_rl/models/loss.py @@ -1,7 +1,7 @@ """Losses for training models, and the context required to evaluate them.""" from abc import abstractmethod -from typing import Dict, Generic, List, Tuple, TypeVar +from typing import Dict, Generic, List, TypeVar import torch import torch.nn as nn @@ -41,11 +41,20 @@ def __init__( class Loss(Generic[ContextT]): """Base class for losses that can be used to define a multiobjective optimization problem.""" - def __init__(self, weight: float = 1.0): + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): """Initialize the loss with a weight.""" - self.weight = weight + self.min_epoch = min_epoch + self.max_epoch = max_epoch + self.target_circuits = target_circuits + self.weights = weights - def __call__(self, context: ContextT) -> Tuple[Tensor, Tensor]: + def __call__(self, context: ContextT) -> Tensor: """Compute the weighted loss for this loss. Args: @@ -54,11 +63,28 @@ def __call__(self, context: ContextT) -> Tuple[Tensor, Tensor]: Returns: ------- - Tuple[Tensor, Tensor]: A tuple containing the weighted loss and the raw loss value. + Tensor: A tuple containing the weighted loss and the raw loss value. """ - value = self.compute_value(context) - return (self.weight * value, value) + return self.compute_value(context) + + def is_training_epoch(self, epoch: int) -> bool: + """Check if the objective should currently be pursued. + + Args: + ---- + epoch (int): Current epoch number. + + Returns: + ------- + bool: True if the objective should continue training, False otherwise. 
+ + """ + if epoch < self.min_epoch: + return False + if self.max_epoch < 0: + return True + return epoch < self.max_epoch @abstractmethod def compute_value(self, context: ContextT) -> Tensor: @@ -74,9 +100,15 @@ def key_name(self) -> str: class ReconstructionLoss(Loss[ContextT]): """Loss for computing the reconstruction loss between inputs and reconstructions.""" - def __init__(self, weight: float = 1.0): + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): """Initialize the reconstruction loss loss.""" - super().__init__(weight) + super().__init__(min_epoch, max_epoch, target_circuits, weights) self.loss_fn = nn.MSELoss(reduction="mean") def compute_value(self, context: ContextT) -> Tensor: @@ -95,10 +127,19 @@ def compute_value(self, context: ContextT) -> Tensor: class L1Sparsity(Loss[ContextT]): """Loss for computing the L1 sparsity of activations.""" - def __init__(self, weight: float, target_responses: List[str]): + def __init__( + self, + target_responses: List[str], + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): + """Initialize the reconstruction loss loss.""" + super().__init__(min_epoch, max_epoch, target_circuits, weights) + """Initialize the L1 sparsity loss.""" self.target_responses = target_responses - super().__init__(weight) def compute_value(self, context: ContextT) -> Tensor: """Compute the L1 sparsity of activations.""" @@ -115,12 +156,18 @@ class KLDivergenceSparsity(Loss[ContextT]): """Loss for computing the KL divergence sparsity of activations.""" def __init__( - self, weight: float, target_responses: List[str], target_sparsity: float = 0.05 + self, + target_responses: List[str], + target_sparsity: float = 0.05, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], ): """Initialize the KL divergence sparsity loss.""" + super().__init__(min_epoch, max_epoch, target_circuits, weights) self.target_responses = target_responses self.target_sparsity = target_sparsity - super().__init__(weight) def compute_value(self, context: ContextT) -> torch.Tensor: """Compute the KL divergence sparsity of activations.""" diff --git a/runner/analyze.py b/runner/analyze.py index e1503df5..b3c97d0e 100644 --- a/runner/analyze.py +++ b/runner/analyze.py @@ -5,10 +5,10 @@ import matplotlib.pyplot as plt import torch +import wandb from matplotlib.figure import Figure from omegaconf import DictConfig -import wandb from retinal_rl.analysis.plot import ( layer_receptive_field_plots, plot_brain_and_optimizers, @@ -25,7 +25,7 @@ ) from retinal_rl.dataset import Imageset from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import ContextT, Goal +from retinal_rl.models.objective import ContextT, Objective logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ def analyze( cfg: DictConfig, device: torch.device, brain: Brain, - goal: Goal[ContextT], + objective: Objective[ContextT], histories: Dict[str, List[float]], train_set: Imageset, test_set: Imageset, @@ -111,7 +111,7 @@ def analyze( if epoch == 0: rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) - graph_fig = plot_brain_and_optimizers(brain, goal) + graph_fig = plot_brain_and_optimizers(brain, objective) _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) transforms = transform_base_images(train_set, num_steps=5, 
num_images=2) transforms_fig = plot_transforms(**transforms) diff --git a/runner/train.py b/runner/train.py index 264324d6..f673e835 100644 --- a/runner/train.py +++ b/runner/train.py @@ -5,16 +5,16 @@ from typing import Dict, List import torch +import wandb from omegaconf import DictConfig from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -import wandb from retinal_rl.classification.loss import ClassificationContext from retinal_rl.classification.training import process_dataset, run_epoch from retinal_rl.dataset import Imageset from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import Goal +from retinal_rl.models.objective import Objective from runner.analyze import analyze from runner.util import save_checkpoint @@ -26,7 +26,7 @@ def train( cfg: DictConfig, device: torch.device, brain: Brain, - goal: Goal[ClassificationContext], + objective: Objective[ClassificationContext], optimizer: Optimizer, train_set: Imageset, test_set: Imageset, @@ -40,7 +40,7 @@ def train( cfg (DictConfig): The configuration for the experiment. device (torch.device): The device to run the computations on. brain (Brain): The Brain model to train and evaluate. - goal (Goal): The optimizer for updating the model parameters. + objective (Objective): The optimizer for updating the model parameters. train_set (Imageset): The training dataset. test_set (Imageset): The test dataset. initial_epoch (int): The epoch to start training from. @@ -56,11 +56,23 @@ def train( if initial_epoch == 0: brain.train() train_losses = process_dataset( - device, brain, goal, optimizer, initial_epoch, trainloader, is_training=False + device, + brain, + objective, + optimizer, + initial_epoch, + trainloader, + is_training=False, ) brain.eval() test_losses = process_dataset( - device, brain, goal, optimizer, initial_epoch, testloader, is_training=False + device, + brain, + objective, + optimizer, + initial_epoch, + testloader, + is_training=False, ) # Initialize the history @@ -75,7 +87,7 @@ def train( cfg, device, brain, - goal, + objective, history, train_set, test_set, @@ -88,11 +100,11 @@ def train( logger.info("Initialization complete.") - for epoch in range(initial_epoch + 1, goal.num_epochs() + 1): + for epoch in range(initial_epoch + 1, objective.num_epochs() + 1): brain, history = run_epoch( device, brain, - goal, + objective, optimizer, history, epoch, @@ -122,7 +134,7 @@ def train( cfg, device, brain, - goal, + objective, history, train_set, test_set, From 8e175142fa04c7531c551eec0eb36fa0d437958c Mon Sep 17 00:00:00 2001 From: Sacha Sokoloski Date: Wed, 16 Oct 2024 17:57:35 +0200 Subject: [PATCH 2/6] Added missing new module --- retinal_rl/models/objective.py | 101 +++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 retinal_rl/models/objective.py diff --git a/retinal_rl/models/objective.py b/retinal_rl/models/objective.py new file mode 100644 index 00000000..47cbed06 --- /dev/null +++ b/retinal_rl/models/objective.py @@ -0,0 +1,101 @@ +"""Module for managing optimization of complex neural network models with multiple circuits.""" + +import logging +from typing import Dict, Generic, List + +import torch +from torch.nn.parameter import Parameter +from torch.optim.optimizer import Optimizer + +from retinal_rl.models.brain import Brain +from retinal_rl.models.loss import ContextT, Loss + +logger = logging.getLogger(__name__) + + +class Objective(Generic[ContextT]): + """Manages multiple loss functions that target NeuralCircuits in a Brain. 
+ + This class handles the initialization, state management, and optimization steps + for multiple optimizers, each associated with specific circuits and objectives. + + + Attributes + ---------- + brain (Brain): The neural network model being optimized. + losses (OrderedDict[str, Optimizer]): Instantiated optimizers, sorted based on connectome. + + """ + + def __init__(self, brain: Brain, optimizer: Optimizer, losses: List[Loss[ContextT]]): + """Initialize the BrainOptimizer. + + Args: + ---- + brain (Brain): The neural network model to optimize. + optimizer (Optimizer): The optimizer to use for training. + losses (List[Loss[ContextT]]): A list of loss functions to optimize. + + Raises: + ------ + ValueError: If a specified circuit is not found in the brain. + + """ + self.device = next(brain.parameters()).device + self.optimizer = optimizer + self.losses: List[Loss[ContextT]] = losses + self.params: List[List[Parameter]] = [] + + for loss in self.losses: + # Collect parameters from target circuits + params: List[Parameter] = [] + for circuit_name in loss.target_circuits: + if circuit_name in brain.circuits: + params.extend(brain.circuits[circuit_name].parameters()) + + self.params.append(params) + + def backward(self, context: ContextT) -> Dict[str, float]: + loss_dict: Dict[str, float] = {} + + retain_graph = True + + for i, (loss, params) in enumerate(zip(self.losses, self.params)): + # Compute losses + name = loss.key_name + weights = loss.weights + value = loss(context) + loss_dict[name] = value.item() + + # Skip training if the optimizer is not at a training epoch + if not loss.is_training_epoch(context.epoch): + continue + + # Set retain_graph to True for all but the last optimizer + retain_graph = i < len(self.losses) - 1 + + # Compute gradients + grads = torch.autograd.grad( + value, params, create_graph=False, retain_graph=retain_graph + ) + + # Manually update parameters + with torch.no_grad(): + for param, weight, grad in zip(params, weights, grads): + if param.grad is None: + param.grad = weight * grad + else: + param.grad += weight * grad + + # Perform optimization step + return loss_dict + + def num_epochs(self) -> int: + """Get the maximum number of epochs over all optimizers. + + Returns + ------- + int: The maximum number of epochs across all losses. + + """ + return max([loss.max_epoch for loss in self.losses]) From bd8238106e8b19ffe0acedba7bac029f5c525794 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 17 Oct 2024 08:12:30 +0200 Subject: [PATCH 3/6] Turning off warnings for doc strings. I'm officially declaring docstrings optional. A good type system should do most of the work, and doc strings fill in the rest (where necessary). --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a4aa92be..aad653ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,10 @@ select = [ "I", # Import conventions ] -ignore = ["E501"] # Example: Ignore line length warnings +ignore = [ + "E501", # Example: Ignore line length warnings + "D", # Ignore all docstring-related warnings +] [tool.ruff.format] docstring-code-format = true From b80c0eb6c8c8c01cf424333886e883199b1a62cd Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 17 Oct 2024 10:25:39 +0200 Subject: [PATCH 4/6] Initial bugs resolved. Seems to be running. Going to work on updating plots. Also here and there stripping about excessive documentation. 
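For reference, a rough sketch of how the rewired config is consumed in
main.py after this commit (assuming a Hydra DictConfig `cfg` matching the
new class-recon.yaml and an already-constructed `brain`; names as in this
series):

    from hydra.utils import instantiate

    # optimizer and objective now live under separate keys of
    # cfg.optimizer and are instantiated independently
    optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters())
    objective = instantiate(cfg.optimizer.objective, brain=brain)

Hydra builds the Objective together with its list of Loss instances, so
the per-loss weights, target_circuits, and epoch windows come straight
from the yaml.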
--- main.py | 2 +- .../user/optimizer/class-recon.yaml | 69 ++++++++++--------- retinal_rl/analysis/plot.py | 48 ++----------- retinal_rl/classification/loss.py | 4 +- retinal_rl/models/loss.py | 30 ++++---- retinal_rl/models/objective.py | 51 ++------------ runner/analyze.py | 7 +- runner/train.py | 13 ++-- 8 files changed, 75 insertions(+), 149 deletions(-) diff --git a/main.py b/main.py index 8547ba93..46de613b 100644 --- a/main.py +++ b/main.py @@ -38,8 +38,8 @@ def _program(cfg: DictConfig): brain = Brain(**cfg.brain).to(device) if hasattr(cfg, "optimizer"): - objective = instantiate(cfg.optimizer.losses, brain=brain) optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters()) + objective = instantiate(cfg.optimizer.objective, brain=brain) else: warnings.warn("No optimizer config specified, is that wanted?") diff --git a/resources/config_templates/user/optimizer/class-recon.yaml b/resources/config_templates/user/optimizer/class-recon.yaml index 1a0b99d9..c0e0a9cb 100644 --- a/resources/config_templates/user/optimizer/class-recon.yaml +++ b/resources/config_templates/user/optimizer/class-recon.yaml @@ -1,36 +1,41 @@ +# Number of training epochs +num_epochs: 100 + +# The optimizer to use optimizer: # torch.optim Class and parameters _target_: torch.optim.Adam lr: 0.0003 -losses: - - _target_: retinal_rl.classification.loss.PercentCorrect - - _target_: retinal_rl.classification.loss.ClassificationLoss - target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction - - retina - - thalamus - - visual_cortex - - inferotemporal - - prefrontal - - classifier - weights: - - ${eval:'1-${recon_weight_retina}'} - - ${eval:'1-${recon_weight_thalamus}'} - - ${eval:'1-${recon_weight_cortex}'} - - 1 - - 1 - - 1 - - _target_: retinal_rl.models.loss.ReconstructionLoss - min_epoch: 0 # Epoch to start optimizer - max_epoch: 100 # Epoch to stop optimizer - target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction - - retina - - thalamus - - visual_cortex - - decoder - - inferotemporal_decoder - weights: - - ${recon_weight_retina} - - ${recon_weight_thalamus} - - ${recon_weight_cortex} - - 1 - - 1 +# The objective function +objective: + _target_: retinal_rl.models.objective.Objective + losses: + - _target_: retinal_rl.classification.loss.PercentCorrect + - _target_: retinal_rl.classification.loss.ClassificationLoss + target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction + - retina + - thalamus + - visual_cortex + - inferotemporal + - prefrontal + - classifier + weights: + - ${eval:'1-${recon_weight_retina}'} + - ${eval:'1-${recon_weight_thalamus}'} + - ${eval:'1-${recon_weight_cortex}'} + - 1 + - 1 + - 1 + - _target_: retinal_rl.models.loss.ReconstructionLoss + target_circuits: # Circuit parameters to optimize with this optimizer. 
We train the retina and the decoder exclusively to maximize reconstruction + - retina + - thalamus + - visual_cortex + - decoder + - inferotemporal_decoder + weights: + - ${recon_weight_retina} + - ${recon_weight_thalamus} + - ${recon_weight_cortex} + - 1 + - 1 diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index 413d64ee..57847230 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -108,14 +108,6 @@ def plot_transforms( def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure: - """Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors. - - Args: - ---- - - brain: The Brain instance - - brain_optimizer: The BrainOptimizer instance - - """ graph = brain.connectome # Compute the depth of each node @@ -146,7 +138,7 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F # Color scheme for different node types color_map = {"sensor": "lightblue", "circuit": "lightgreen"} - # Generate colors for optimizers + # Generate colors for losses optimizer_colors = sns.color_palette("husl", len(objective.losses)) # Prepare node colors and edge colors @@ -229,13 +221,7 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure: - """Plot the receptive field sizes for each layer of the convolutional part of the network. - - Args: - ---- - - results: Dictionary containing the results from cnn_statistics function - - """ + """Plot the receptive field sizes for each layer of the convolutional part of the network.""" # Get visual field size from the input shape input_shape = results["input"]["shape"] [_, height, width] = list(input_shape) @@ -300,17 +286,7 @@ def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Fig def plot_histories(histories: Dict[str, List[float]]) -> Figure: - """Plot training and test losses over epochs. - - Args: - ---- - histories (Dict[str, List[float]]): Dictionary containing training and test loss histories. - - Returns: - ------- - Figure: Matplotlib figure containing the plotted histories. - - """ + """Plot training and test losses over epochs.""" train_metrics = [ key.split("_", 1)[1] for key in histories.keys() if key.startswith("train_") ] @@ -467,23 +443,7 @@ def plot_reconstructions( test_estimates: List[Tuple[Tensor, int]], num_samples: int, ) -> Figure: - """Plot original and reconstructed images for both training and test sets, including the classes. - - Args: - ---- - train_sources (List[Tuple[Tensor, int]]): List of original source images and their classes. - train_inputs (List[Tuple[Tensor, int]]): List of original training images and their classes. - train_estimates (List[Tuple[Tensor, int]]): List of reconstructed training images and their predicted classes. - test_sources (List[Tuple[Tensor, int]]): List of original source images and their classes. - test_inputs (List[Tuple[Tensor, int]]): List of original test images and their classes. - test_estimates (List[Tuple[Tensor, int]]): List of reconstructed test images and their predicted classes. - num_samples (int): The number of samples to plot. - - Returns: - ------- - Figure: The matplotlib Figure object with the plotted images. 
- - """ + """Plot original and reconstructed images for both training and test sets, including the classes.""" fig, axes = plt.subplots(6, num_samples, figsize=(15, 10)) for i in range(num_samples): diff --git a/retinal_rl/classification/loss.py b/retinal_rl/classification/loss.py index 8fcbca31..5627a93a 100644 --- a/retinal_rl/classification/loss.py +++ b/retinal_rl/classification/loss.py @@ -41,7 +41,7 @@ class ClassificationLoss(Loss[ClassificationContext]): def __init__( self, min_epoch: int = 0, - max_epoch: int = 1, + max_epoch: int = -1, target_circuits: List[str] = [], weights: List[float] = [], ): @@ -68,7 +68,7 @@ class PercentCorrect(Loss[ClassificationContext]): def __init__( self, min_epoch: int = 0, - max_epoch: int = 1, + max_epoch: int = -1, target_circuits: List[str] = [], weights: List[float] = [], ): diff --git a/retinal_rl/models/loss.py b/retinal_rl/models/loss.py index 6a0e7073..0c16be3d 100644 --- a/retinal_rl/models/loss.py +++ b/retinal_rl/models/loss.py @@ -39,12 +39,21 @@ def __init__( class Loss(Generic[ContextT]): - """Base class for losses that can be used to define a multiobjective optimization problem.""" + """Base class for losses that can be used to define a multiobjective optimization problem. + + Attributes + ---------- + min_epoch (int): The minimum epoch to start training the loss. + max_epoch (int): The maximum epoch to train the loss. Unbounded if < 0. + target_circuits (List[str]): The target circuits for the loss. + weights (List[float]): The weights for the loss. + + """ def __init__( self, min_epoch: int = 0, - max_epoch: int = 1, + max_epoch: int = -1, target_circuits: List[str] = [], weights: List[float] = [], ): @@ -55,17 +64,6 @@ def __init__( self.weights = weights def __call__(self, context: ContextT) -> Tensor: - """Compute the weighted loss for this loss. - - Args: - ---- - context (ContextT): Context information for computing losses. - - Returns: - ------- - Tensor: A tuple containing the weighted loss and the raw loss value. - - """ return self.compute_value(context) def is_training_epoch(self, epoch: int) -> bool: @@ -103,7 +101,7 @@ class ReconstructionLoss(Loss[ContextT]): def __init__( self, min_epoch: int = 0, - max_epoch: int = 1, + max_epoch: int = -1, target_circuits: List[str] = [], weights: List[float] = [], ): @@ -131,7 +129,7 @@ def __init__( self, target_responses: List[str], min_epoch: int = 0, - max_epoch: int = 1, + max_epoch: int = -1, target_circuits: List[str] = [], weights: List[float] = [], ): @@ -160,7 +158,7 @@ def __init__( target_responses: List[str], target_sparsity: float = 0.05, min_epoch: int = 0, - max_epoch: int = 1, + max_epoch: int = -1, target_circuits: List[str] = [], weights: List[float] = [], ): diff --git a/retinal_rl/models/objective.py b/retinal_rl/models/objective.py index 47cbed06..a2cae818 100644 --- a/retinal_rl/models/objective.py +++ b/retinal_rl/models/objective.py @@ -5,7 +5,6 @@ import torch from torch.nn.parameter import Parameter -from torch.optim.optimizer import Optimizer from retinal_rl.models.brain import Brain from retinal_rl.models.loss import ContextT, Loss @@ -14,37 +13,10 @@ class Objective(Generic[ContextT]): - """Manages multiple loss functions that target NeuralCircuits in a Brain. - - This class handles the initialization, state management, and optimization steps - for multiple optimizers, each associated with specific circuits and objectives. - - - Attributes - ---------- - brain (Brain): The neural network model being optimized. 
- losses (OrderedDict[str, Optimizer]): Instantiated optimizers, sorted based on connectome. - - """ - - def __init__(self, brain: Brain, optimizer: Optimizer, losses: List[Loss[ContextT]]): - """Initialize the BrainOptimizer. - - Args: - ---- - brain (Brain): The neural network model to optimize. - optimizer (Optimizer): The optimizer to use for training. - losses (List[Loss[ContextT]]): A list of loss functions to optimize. - - Raises: - ------ - ValueError: If a specified circuit is not found in the brain. - - """ + def __init__(self, brain: Brain, losses: List[Loss[ContextT]]): self.device = next(brain.parameters()).device - self.optimizer = optimizer self.losses: List[Loss[ContextT]] = losses - self.params: List[List[Parameter]] = [] + self.paramss: List[List[Parameter]] = [] for loss in self.losses: # Collect parameters from target circuits @@ -52,23 +24,20 @@ def __init__(self, brain: Brain, optimizer: Optimizer, losses: List[Loss[Context for circuit_name in loss.target_circuits: if circuit_name in brain.circuits: params.extend(brain.circuits[circuit_name].parameters()) - - self.params.append(params) + self.paramss.append(params) def backward(self, context: ContextT) -> Dict[str, float]: loss_dict: Dict[str, float] = {} retain_graph = True - for i, (loss, params) in enumerate(zip(self.losses, self.params)): + for i, (loss, params) in enumerate(zip(self.losses, self.paramss)): # Compute losses name = loss.key_name weights = loss.weights value = loss(context) loss_dict[name] = value.item() - - # Skip training if the optimizer is not at a training epoch - if not loss.is_training_epoch(context.epoch): + if not loss.is_training_epoch(context.epoch) or not params: continue # Set retain_graph to True for all but the last optimizer @@ -89,13 +58,3 @@ def backward(self, context: ContextT) -> Dict[str, float]: # Perform optimization step return loss_dict - - def num_epochs(self) -> int: - """Get the maximum number of epochs over all optimizers. - - Returns - ------- - int: The maximum number of epochs across all losses. 
- - """ - return max([loss.max_epoch for loss in self.losses]) diff --git a/runner/analyze.py b/runner/analyze.py index b3c97d0e..24d9a92d 100644 --- a/runner/analyze.py +++ b/runner/analyze.py @@ -5,13 +5,12 @@ import matplotlib.pyplot as plt import torch -import wandb from matplotlib.figure import Figure from omegaconf import DictConfig +import wandb from retinal_rl.analysis.plot import ( layer_receptive_field_plots, - plot_brain_and_optimizers, plot_channel_statistics, plot_histories, plot_receptive_field_sizes, @@ -111,8 +110,8 @@ def analyze( if epoch == 0: rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) - graph_fig = plot_brain_and_optimizers(brain, objective) - _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) + # graph_fig = plot_brain_and_optimizers(brain, objective) + # _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) transforms = transform_base_images(train_set, num_steps=5, num_images=2) transforms_fig = plot_transforms(**transforms) _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) diff --git a/runner/train.py b/runner/train.py index f673e835..b9e9b12c 100644 --- a/runner/train.py +++ b/runner/train.py @@ -5,11 +5,11 @@ from typing import Dict, List import torch -import wandb from omegaconf import DictConfig from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +import wandb from retinal_rl.classification.loss import ClassificationContext from retinal_rl.classification.training import process_dataset, run_epoch from retinal_rl.dataset import Imageset @@ -51,7 +51,6 @@ def train( testloader = DataLoader(test_set, batch_size=64, shuffle=False) wall_time = time.time() - epoch_wall_time = 0 if initial_epoch == 0: brain.train() @@ -95,12 +94,18 @@ def train( True, ) + new_wall_time = time.time() + epoch_wall_time = new_wall_time - wall_time + wall_time = new_wall_time + logger.info("Initialization complete. Wall Time: {epoch_wall_time:.2f}s.") + if cfg.use_wandb: _wandb_log_statistics(initial_epoch, epoch_wall_time, history) - logger.info("Initialization complete.") + else: + logger.info(f"Reloading complete. Resuming training from epoch {initial_epoch}.") - for epoch in range(initial_epoch + 1, objective.num_epochs() + 1): + for epoch in range(initial_epoch + 1, cfg.optimizer.num_epochs + 1): brain, history = run_epoch( device, brain, From f32f74120a91ab36575d9cdf3e497a1031de2492 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 17 Oct 2024 15:47:58 +0200 Subject: [PATCH 5/6] More bugs resolved. Seems to actually run now, and be more efficient. Trying to see if I can do something useful with the graph plot. 
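The efficiency gain comes from taking a single autograd pass per loss and
scaling the returned gradients by per-parameter weights (expanded from the
per-circuit weights by _weighted_params). A minimal standalone sketch of
that accumulation scheme, with toy tensors standing in for circuit
parameters:

    import torch

    # two "circuits" with one parameter each, and their loss weights
    params = [torch.randn(4, requires_grad=True) for _ in range(2)]
    weights = [0.9, 1.0]
    value = sum((p**2).sum() for p in params)  # stand-in for a Loss value

    # one backward pass for this loss; the graph is retained only when
    # later losses still need it
    grads = torch.autograd.grad(value, params, retain_graph=False)
    with torch.no_grad():
        for param, weight, grad in zip(params, weights, grads):
            if param.grad is None:
                param.grad = weight * grad
            else:
                param.grad += weight * grad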
--- retinal_rl/analysis/plot.py | 108 +++++++++++++++++++++++++-------- retinal_rl/models/objective.py | 30 +++++---- runner/analyze.py | 5 +- runner/train.py | 2 +- 4 files changed, 106 insertions(+), 39 deletions(-) diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index 57847230..259deaec 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -13,6 +13,7 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.lines import Line2D +from matplotlib.patches import Circle, Wedge from matplotlib.ticker import MaxNLocator from torch import Tensor from torchvision.utils import make_grid @@ -130,44 +131,101 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F pos[node] = ((i - width / 2) / (width + 1), -(max_depth - depth) / max_depth) # Set up the plot - fig = plt.figure(figsize=(12, 10)) + fig, ax = plt.subplots(figsize=(12, 10)) # Draw edges - nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True) + nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, ax=ax) # Color scheme for different node types color_map = {"sensor": "lightblue", "circuit": "lightgreen"} # Generate colors for losses - optimizer_colors = sns.color_palette("husl", len(objective.losses)) + loss_colors = sns.color_palette("husl", len(objective.losses)) - # Prepare node colors and edge colors - node_colors: List[str] = [] - edge_colors: List[Tuple[float, float, float]] = [] + # Draw nodes for node in graph.nodes(): + x, y = pos[node] + + # Determine node type and base color if node in brain.sensors: - node_colors.append(color_map["sensor"]) + base_color = color_map["sensor"] else: - node_colors.append(color_map["circuit"]) - - # Determine if the node is targeted by an optimizer - edge_color = "none" - for i, optimizer_name in enumerate(objective.losses.keys()): - if node in objective.target_circuits[optimizer_name]: - edge_color = optimizer_colors[i] - break - edge_colors.append(edge_color) - - # Draw nodes with a single call - nx.draw_networkx_nodes( - graph, - pos, - node_color=node_colors, - edgecolors=edge_colors, - node_size=4000, - linewidths=5, + base_color = color_map["circuit"] + + # Draw base circle + circle = Circle((x, y), 0.05, facecolor=base_color, edgecolor="black") + ax.add_patch(circle) + + # Determine which losses target this node + targeting_losses = [ + loss for loss in objective.losses if node in loss.target_circuits + ] + + if targeting_losses: + # Calculate angle for each loss + angle_per_loss = 360 / len(targeting_losses) + + # Draw a wedge for each targeting loss + for i, loss in enumerate(targeting_losses): + start_angle = i * angle_per_loss + wedge = Wedge( + (x, y), + 0.07, + start_angle, + start_angle + angle_per_loss, + width=0.02, + facecolor=loss_colors[objective.losses.index(loss)], + ) + ax.add_patch(wedge) + + # Draw labels + nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold", ax=ax) + + # Add a legend for losses + legend_elements = [ + Line2D( + [0], + [0], + marker="o", + color="w", + label=f"Loss: {loss.__class__.__name__}", + markerfacecolor=color, + markersize=15, + ) + for loss, color in zip(objective.losses, loss_colors) + ] + + # Add legend elements for sensor and circuit + legend_elements.extend( + [ + Line2D( + [0], + [0], + marker="o", + color="w", + label="Sensor", + markerfacecolor=color_map["sensor"], + markersize=15, + ), + Line2D( + [0], + [0], + marker="o", + color="w", + label="Circuit", + markerfacecolor=color_map["circuit"], 
+ markersize=15, + ), + ] ) + plt.legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1, 0.5)) + + plt.title("Brain Connectome and Loss Targets") + plt.tight_layout() + plt.axis("off") + + return fig # Draw labels nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold") diff --git a/retinal_rl/models/objective.py b/retinal_rl/models/objective.py index a2cae818..2eaa06d6 100644 --- a/retinal_rl/models/objective.py +++ b/retinal_rl/models/objective.py @@ -1,7 +1,7 @@ """Module for managing optimization of complex neural network models with multiple circuits.""" import logging -from typing import Dict, Generic, List +from typing import Dict, Generic, List, Tuple import torch from torch.nn.parameter import Parameter @@ -16,25 +16,20 @@ class Objective(Generic[ContextT]): def __init__(self, brain: Brain, losses: List[Loss[ContextT]]): self.device = next(brain.parameters()).device self.losses: List[Loss[ContextT]] = losses - self.paramss: List[List[Parameter]] = [] + self.brain: Brain = brain - for loss in self.losses: - # Collect parameters from target circuits - params: List[Parameter] = [] - for circuit_name in loss.target_circuits: - if circuit_name in brain.circuits: - params.extend(brain.circuits[circuit_name].parameters()) - self.paramss.append(params) + # Build a dictionary of weighted parameters for each loss + # TODO: If the parameters() list of a neural circuit changes dynamically, this will break def backward(self, context: ContextT) -> Dict[str, float]: loss_dict: Dict[str, float] = {} retain_graph = True - for i, (loss, params) in enumerate(zip(self.losses, self.paramss)): + for i, loss in enumerate(self.losses): # Compute losses + weights, params = self._weighted_params(loss) name = loss.key_name - weights = loss.weights value = loss(context) loss_dict[name] = value.item() if not loss.is_training_epoch(context.epoch) or not params: @@ -58,3 +53,16 @@ def backward(self, context: ContextT) -> Dict[str, float]: # Perform optimization step return loss_dict + + def _weighted_params( + self, loss: Loss[ContextT] + ) -> Tuple[List[float], List[Parameter]]: + weights: List[float] = [] + params: List[Parameter] = [] + for weight, circuit_name in zip(loss.weights, loss.target_circuits): + if circuit_name in self.brain.circuits: + params0 = list(self.brain.circuits[circuit_name].parameters()) + weights += [weight] * len(params0) + params += params0 + + return weights, params diff --git a/runner/analyze.py b/runner/analyze.py index 24d9a92d..e522d826 100644 --- a/runner/analyze.py +++ b/runner/analyze.py @@ -11,6 +11,7 @@ import wandb from retinal_rl.analysis.plot import ( layer_receptive_field_plots, + plot_brain_and_optimizers, plot_channel_statistics, plot_histories, plot_receptive_field_sizes, @@ -110,8 +111,8 @@ def analyze( if epoch == 0: rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) - # graph_fig = plot_brain_and_optimizers(brain, objective) - # _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) + graph_fig = plot_brain_and_optimizers(brain, objective) + _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) transforms = transform_base_images(train_set, num_steps=5, num_images=2) transforms_fig = plot_transforms(**transforms) _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) diff --git a/runner/train.py b/runner/train.py index b9e9b12c..75876d90 100644 --- a/runner/train.py +++ b/runner/train.py @@ -97,7 +97,7 @@ def 
train( new_wall_time = time.time() epoch_wall_time = new_wall_time - wall_time wall_time = new_wall_time - logger.info("Initialization complete. Wall Time: {epoch_wall_time:.2f}s.") + logger.info(f"Initialization complete. Wall Time: {epoch_wall_time:.2f}s.") if cfg.use_wandb: _wandb_log_statistics(initial_epoch, epoch_wall_time, history) From 2c44b6a27da9c3430f2b42e69ca825619b02fc54 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 17 Oct 2024 15:58:11 +0200 Subject: [PATCH 6/6] Brain graph plot updated --- retinal_rl/analysis/plot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index 259deaec..eae79c78 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -223,6 +223,7 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F plt.title("Brain Connectome and Loss Targets") plt.tight_layout() + plt.axis("equal") plt.axis("off") return fig
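For the record, an end-to-end sketch of the refactored API as it stands
after this series. The loss construction mirrors class-recon.yaml (with
illustrative numeric weights standing in for the ${recon_weight_*}
interpolations), and the training step mirrors process_dataset; the Brain
instance, dataloader, and ClassificationContext construction are assumed
from the existing codebase:

    from torch.optim import Adam

    from retinal_rl.classification.loss import ClassificationLoss, PercentCorrect
    from retinal_rl.models.loss import ReconstructionLoss
    from retinal_rl.models.objective import Objective

    losses = [
        # no target_circuits, so it is logged but never drives a gradient
        PercentCorrect(),
        ClassificationLoss(
            target_circuits=["retina", "thalamus", "visual_cortex",
                             "inferotemporal", "prefrontal", "classifier"],
            weights=[0.7, 0.5, 0.1, 1, 1, 1],
        ),
        ReconstructionLoss(
            target_circuits=["retina", "thalamus", "visual_cortex",
                             "decoder", "inferotemporal_decoder"],
            weights=[0.3, 0.5, 0.9, 1, 1],
        ),
    ]

    objective = Objective(brain, losses)  # brain: a constructed Brain
    optimizer = Adam(brain.parameters(), lr=0.0003)

    for epoch in range(1, num_epochs + 1):
        for batch in trainloader:
            context = ...  # build a ClassificationContext for this batch
            loss_dict = objective.backward(context)  # accumulate weighted grads
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)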