Merge pull request #64 from berenslab/separate-log-statistics
Separate log statistics
alex404 authored Dec 5, 2024
2 parents e199a6f + d065e18 commit a197daa
Showing 7 changed files with 212 additions and 56 deletions.
3 changes: 2 additions & 1 deletion resources/config_templates/user/optimizer/class-recon.yaml
@@ -11,7 +11,6 @@ optimizer: # torch.optim Class and parameters
objective:
_target_: retinal_rl.models.objective.Objective
losses:
- _target_: retinal_rl.classification.loss.PercentCorrect
- _target_: retinal_rl.classification.loss.ClassificationLoss
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- retina
@@ -55,3 +54,5 @@ objective:
- ${sparsity_weight}
- ${sparsity_weight}
- ${sparsity_weight}
logging_statistics:
- _target_: retinal_rl.classification.loss.PercentCorrect
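This is the substance of the config change: PercentCorrect moves out of losses and into the new logging_statistics list, so it is still reported every step but no longer contributes gradients. A rough sketch of what this YAML resolves to through Hydra's _target_ instantiation (the brain instance and the circuit/weight values are placeholders, not taken from this template):

    from retinal_rl.classification.loss import ClassificationLoss, PercentCorrect
    from retinal_rl.models.objective import Objective

    # brain: a retinal_rl.models.brain.Brain instance, assumed already built
    objective = Objective(
        brain=brain,
        losses=[
            # placeholder circuits/weights, for illustration only
            ClassificationLoss(target_circuits=["retina", "decoder"], weights=[1.0, 1.0]),
        ],
        logging_statistics=[
            PercentCorrect(),  # logged every step, never backpropagated
        ],
    )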
17 changes: 4 additions & 13 deletions retinal_rl/classification/loss.py
@@ -6,7 +6,7 @@
from torch import Tensor, nn

from retinal_rl.models.brain import Brain
from retinal_rl.models.loss import BaseContext, Loss
from retinal_rl.models.loss import BaseContext, LoggingStatistic, Loss


class ClassificationContext(BaseContext):
@@ -39,8 +39,8 @@ class ClassificationLoss(Loss[ClassificationContext]):

def __init__(
self,
target_circuits: List[str] = [],
weights: List[float] = [],
target_circuits: Optional[List[str]] = None,
weights: Optional[List[float]] = None,
min_epoch: Optional[int] = None,
max_epoch: Optional[int] = None,
):
@@ -61,18 +61,9 @@ def compute_value(self, context: ClassificationContext) -> Tensor:
return self.loss_fn(predictions, classes)


class PercentCorrect(Loss[ClassificationContext]):
class PercentCorrect(LoggingStatistic[ClassificationContext]):
"""(Inverse) Loss for computing the percent correct classification."""

def __init__(
self,
target_circuits: List[str] = [],
weights: List[float] = [],
min_epoch: Optional[int] = None,
max_epoch: Optional[int] = None,
):
super().__init__(target_circuits, weights, min_epoch, max_epoch)

def compute_value(self, context: ClassificationContext) -> Tensor:
"""Compute the percent correct classification."""
predictions = context.responses["classifier"]
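PercentCorrect keeps its computation; only the base class changes, and its boilerplate __init__ disappears because LoggingStatistic carries no circuits, weights, or epoch bounds. A minimal sketch of the resulting class, with the truncated body reconstructed under the assumption that the context exposes labels as context.classes (mirroring ClassificationLoss above):

    import torch
    from torch import Tensor

    class PercentCorrect(LoggingStatistic[ClassificationContext]):
        """(Inverse) Loss for computing the percent correct classification."""

        def compute_value(self, context: ClassificationContext) -> Tensor:
            """Compute the percent correct classification."""
            predictions = context.responses["classifier"]
            classes = context.classes  # assumed attribute, as used by ClassificationLoss
            # fraction of samples whose argmax prediction matches the label
            return (torch.argmax(predictions, dim=1) == classes).float().mean()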
55 changes: 32 additions & 23 deletions retinal_rl/models/loss.py
@@ -37,12 +37,29 @@ def __init__(
self.epoch = epoch


class Loss(Generic[ContextT]):
class LoggingStatistic(Generic[ContextT]):
"""Base class for statistics that should be logged."""

def __call__(self, context: ContextT) -> Tensor:
return self.compute_value(context)

@abstractmethod
def compute_value(self, context: ContextT) -> Tensor:
"""Compute the value for this losses The context dictionary contains the necessary information to compute the loss."""
pass

@property
def key_name(self) -> str:
"""Return a user-friendly name for the loss."""
return camel_to_snake(self.__class__.__name__)


class Loss(LoggingStatistic[ContextT]):
"""Base class for losses that can be used to define a multiobjective optimization problem.
Attributes
----------
target_circuits (List[str]): The target circuits for the loss.
target_circuits (List[str]): The target circuits for the loss. If '__all__', will target all circuits.
weights (List[float]): The weights for the loss.
min_epoch (int): The minimum epoch to start training the loss.
max_epoch (int): The maximum epoch to train the loss. Unbounded if < 0.
@@ -51,20 +68,22 @@ class Loss(Generic[ContextT]):

def __init__(
self,
target_circuits: List[str] = [],
weights: List[float] = [],
target_circuits: Optional[List[str]] = None,
weights: Optional[List[float]] = None,
min_epoch: Optional[int] = None,
max_epoch: Optional[int] = None,
):
"""Initialize the loss with a weight."""
if target_circuits is None:
target_circuits = []
if weights is None:
weights = [1]

self.target_circuits = target_circuits
self.weights = weights
self.min_epoch = min_epoch
self.max_epoch = max_epoch

def __call__(self, context: ContextT) -> Tensor:
return self.compute_value(context)

def is_training_epoch(self, epoch: int) -> bool:
"""Check if the objective should currently be pursued.
@@ -81,25 +100,15 @@ def is_training_epoch(self, epoch: int) -> bool:
return False
return self.max_epoch is None or epoch <= self.max_epoch

@abstractmethod
def compute_value(self, context: ContextT) -> Tensor:
"""Compute the value for this losses The context dictionary contains the necessary information to compute the loss."""
pass

@property
def key_name(self) -> str:
"""Return a user-friendly name for the loss."""
return camel_to_snake(self.__class__.__name__)


class ReconstructionLoss(Loss[ContextT]):
"""Loss for computing the reconstruction loss between inputs and reconstructions."""

def __init__(
self,
target_decoder: str,
target_circuits: List[str] = [],
weights: List[float] = [],
target_circuits: Optional[List[str]] = None,
weights: Optional[List[float]] = None,
min_epoch: Optional[int] = None,
max_epoch: Optional[int] = None,
):
@@ -132,8 +141,8 @@ class L1Sparsity(Loss[ContextT]):
def __init__(
self,
target_response: str,
target_circuits: List[str] = [],
weights: List[float] = [],
target_circuits: Optional[List[str]] = None,
weights: Optional[List[float]] = None,
min_epoch: Optional[int] = None,
max_epoch: Optional[int] = None,
):
@@ -163,8 +172,8 @@ def __init__(
self,
target_response: str,
target_sparsity: float = 0.05,
target_circuits: List[str] = [],
weights: List[float] = [],
target_circuits: Optional[List[str]] = None,
weights: Optional[List[float]] = None,
min_epoch: Optional[int] = None,
max_epoch: Optional[int] = None,
):
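The net effect in this file: __call__, compute_value, and key_name move up into the new LoggingStatistic base class, Loss now inherits from it and keeps only the optimization machinery (target circuits, weights, epoch bounds), and the mutable default arguments (List[str] = []) are replaced with the safer Optional[...] = None idiom. Any purely observational metric can now subclass LoggingStatistic directly. A minimal hypothetical example (this class is illustrative, not part of the commit, and assumes the context carries a responses dict as in the classification code above):

    from torch import Tensor

    class MeanClassifierResponse(LoggingStatistic[ClassificationContext]):
        """Hypothetical statistic: mean activation of the classifier output."""

        def compute_value(self, context: ClassificationContext) -> Tensor:
            # purely observational: no circuits, weights, or epoch bounds
            return context.responses["classifier"].mean()

    # key_name converts the class name via camel_to_snake, so this
    # statistic would appear in logs as "mean_classifier_response".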
44 changes: 37 additions & 7 deletions retinal_rl/models/objective.py
@@ -1,21 +1,38 @@
"""Module for managing optimization of complex neural network models with multiple circuits."""

import logging
from typing import Dict, Generic, List, Tuple
from typing import Dict, Generic, List, Optional, Tuple

import torch
from torch.nn.parameter import Parameter

from retinal_rl.models.brain import Brain
from retinal_rl.models.loss import ContextT, Loss
from retinal_rl.models.loss import ContextT, LoggingStatistic, Loss

logger = logging.getLogger(__name__)


class Objective(Generic[ContextT]):
def __init__(self, brain: Brain, losses: List[Loss[ContextT]]):
def __init__(
self,
brain: Brain,
losses: List[Loss[ContextT]],
logging_statistics: Optional[List[LoggingStatistic[ContextT]]] = None,
):
if logging_statistics is None:
logging_statistics = []

for loss in losses:
assert isinstance(loss, Loss), "losses need to subclass Loss"

for stat in logging_statistics:
assert isinstance(
stat, LoggingStatistic
), "logging_statistics need to subclass LoggingStatistic"

self.device = next(brain.parameters()).device
self.losses: List[Loss[ContextT]] = losses
self.losses = losses
self.logging_statistics = logging_statistics
self.brain: Brain = brain

# Build a dictionary of weighted parameters for each loss
@@ -26,12 +43,16 @@ def backward(self, context: ContextT) -> Dict[str, float]:

retain_graph = True

for i, stat in enumerate(self.logging_statistics):
loss_dict[stat.key_name] = stat(context).item()

for i, loss in enumerate(self.losses):
# Compute losses
weights, params = self._weighted_params(loss)
name = loss.key_name
value = loss(context)
loss_dict[name] = value.item()

# Compute losses
weights, params = self._weighted_params(loss)
if not loss.is_training_epoch(context.epoch) or not params:
continue

@@ -57,9 +78,18 @@ def backward(self, context: ContextT) -> Dict[str, float]:
def _weighted_params(
self, loss: Loss[ContextT]
) -> Tuple[List[float], List[Parameter]]:
_targets = loss.target_circuits
_weights = loss.weights

if "__all__" in _targets:
_targets = self.brain.circuits.keys()
if len(_weights) == 1:
_weights = [_weights[0] for _ in range(len(_targets))]
assert len(_weights) == len(_targets)

weights: List[float] = []
params: List[Parameter] = []
for weight, circuit_name in zip(loss.weights, loss.target_circuits):
for weight, circuit_name in zip(_weights, _targets):
if circuit_name in self.brain.circuits:
params0 = list(self.brain.circuits[circuit_name].parameters())
weights += [weight] * len(params0)
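backward now records every logging statistic in the returned dict before looping over the losses, and _weighted_params understands the '__all__' sentinel: it expands to every circuit in the brain and broadcasts a single weight across them. A rough usage sketch (assumes brain and a populated context are in scope; the numeric values are illustrative):

    objective = Objective(
        brain,
        losses=[
            # a single weight, broadcast across all of the brain's circuits
            ClassificationLoss(target_circuits=["__all__"], weights=[1.0]),
        ],
        logging_statistics=[PercentCorrect()],
    )

    loss_dict = objective.backward(context)
    # keys come from key_name, e.g.
    # {"percent_correct": 0.73, "classification_loss": 0.61}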
30 changes: 23 additions & 7 deletions tests/modules/conftest.py
@@ -2,6 +2,7 @@
import shutil
import sys
import time
from typing import Generator

import hydra
import pytest
@@ -13,10 +14,9 @@
OmegaConf.register_new_resolver("eval", eval)


@pytest.fixture
def config() -> DictConfig:
# TODO: make this independent of whether templates are in the right place or not etc
def config(experiment: str) -> DictConfig:
with hydra.initialize(config_path="../../config/base", version_base=None):
experiment = "gathering-apples"
config = hydra.compose(
"config", overrides=[f"+experiment={experiment}", "system.device=cpu"]
)
@@ -35,11 +35,27 @@ def config() -> DictConfig:
), "hydra: values can not be resolved here. Set them manually in this fixture for tests!"

OmegaConf.resolve(config)
yield config
return config


def cleanup(config: DictConfig):
# Cleanup: remove temporary dir
if os.path.exists(config.path.run_dir):
shutil.rmtree(config.path.run_dir)


# Cleanup: remove temporary dir
if os.path.exists(config.path.run_dir):
shutil.rmtree(config.path.run_dir)
@pytest.fixture
def classification_config() -> Generator[DictConfig, None, None]:
_config = config("classification")
yield _config
cleanup(_config)


@pytest.fixture
def rl_config() -> Generator[DictConfig, None, None]:
_config = config("gathering-apples")
yield _config
cleanup(_config)


@pytest.fixture
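The old all-in-one fixture is split apart: config becomes a plain helper parameterized by experiment name, and the classification_config and rl_config fixtures wrap it with the run-directory cleanup that previously sat after the yield. A test simply requests the fixture it needs; a minimal sketch (the assertion is illustrative, based on the system.device=cpu override above):

    from omegaconf import DictConfig

    def test_device_override(classification_config: DictConfig):
        # the fixture yields a fully resolved config for the "classification"
        # experiment and removes its run_dir afterwards
        assert classification_config.system.device == "cpu"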