From 225f70fe963d8c5540ff8fb84d9d9a5adfb85668 Mon Sep 17 00:00:00 2001
From: Lukas Balles
Date: Mon, 18 Sep 2023 11:13:27 +0200
Subject: [PATCH 1/3] Add first version of simplified selective ER

---
 src/renate/cli/parsing_functions.py       |  18 +++
 .../updaters/experimental/selective_er.py | 144 ++++++++++++++++++
 2 files changed, 162 insertions(+)
 create mode 100644 src/renate/updaters/experimental/selective_er.py

diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py
index 0d94173e..6d784477 100644
--- a/src/renate/cli/parsing_functions.py
+++ b/src/renate/cli/parsing_functions.py
@@ -28,6 +28,7 @@
 )
 from renate.updaters.experimental.offline_er import OfflineExperienceReplayModelUpdater
 from renate.updaters.experimental.repeated_distill import RepeatedDistillationModelUpdater
+from renate.updaters.experimental.selective_er import SelectiveExperienceReplayModelUpdater
 from renate.updaters.model_updater import ModelUpdater
 
 REQUIRED_ARGS_GROUP = "Required Arguments"
@@ -105,6 +106,9 @@ def get_updater_and_learner_kwargs(
     elif args.updater == "Offline-ER":
         learner_args = learner_args + ["loss_weight_new_data", "memory_size", "batch_memory_frac"]
         updater_class = OfflineExperienceReplayModelUpdater
+    elif args.updater == "Selective-ER":
+        learner_args = learner_args + ["subsampling_ratio", "memory_size", "batch_memory_frac"]
+        updater_class = SelectiveExperienceReplayModelUpdater
     elif args.updater == "RD":
         learner_args = learner_args + ["memory_size"]
         updater_class = RepeatedDistillationModelUpdater
@@ -527,6 +531,19 @@ def _add_offline_er_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
     )
 
 
+def _add_selective_er_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
+    _add_replay_learner_arguments(arguments)
+    arguments.update(
+        {
+            "subsampling_ratio": {
+                "type": float,
+                "default": None,
+                "help": "TODO",
+            }
+        }
+    )
+
+
 def _add_experience_replay_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
     """A helper function that adds Experience Replay arguments."""
     arguments.update(
@@ -982,6 +999,7 @@ def get_scheduler_kwargs(
     "FineTuning": _add_finetuning_arguments,
     "RD": _add_rd_learner_arguments,
     "Offline-ER": _add_offline_er_arguments,
+    "Selective-ER": _add_selective_er_arguments,
     "Avalanche-ER": _add_experience_replay_arguments,
     "Avalanche-EWC": _add_avalanche_ewc_learner_arguments,
     "Avalanche-LwF": _add_avalanche_lwf_learner_arguments,
diff --git a/src/renate/updaters/experimental/selective_er.py b/src/renate/updaters/experimental/selective_er.py
new file mode 100644
index 00000000..4f53c8e0
--- /dev/null
+++ b/src/renate/updaters/experimental/selective_er.py
@@ -0,0 +1,144 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+import torchmetrics
+from pytorch_lightning.loggers.logger import Logger
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch.nn import Parameter
+from torch.optim import Optimizer
+
+from renate import defaults
+from renate.models import RenateModule
+from renate.types import NestedTensors
+from renate.updaters.experimental.offline_er import OfflineExperienceReplayLearner
+from renate.updaters.model_updater import SingleTrainingLoopUpdater
+from renate.utils.misc import maybe_populate_mask_and_ignore_logits
+from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors
+
+
+class SelectiveExperienceReplayLearner(OfflineExperienceReplayLearner):
+    """(Offline) experience replay with selective backprop.
+
+    Args:
+        TODO
+    """
+
+    def __init__(self, subsampling_ratio: float = 0.5, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self._effective_batch_size = round(subsampling_ratio * self._batch_size)
+        if not self._effective_batch_size > 0:
+            raise ValueError(
+                f"Subsampling ratio {subsampling_ratio} results in an effective batch size of 0. "
+                "Choose a larger subsampling ratio."
+            )
+
+    def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) -> STEP_OUTPUT:
+        """PyTorch Lightning function to return the training loss."""
+        inputs, targets = batch["current_task"]
+        batch_size_current = get_length_nested_tensors(inputs)
+        if "memory" in batch:
+            (inputs_mem, targets_mem), _ = batch["memory"]
+            inputs = cat_nested_tensors((inputs, inputs_mem), 0)
+            targets = torch.cat((targets, targets_mem), 0)
+        outputs = self(inputs)
+
+        outputs, self._class_mask = maybe_populate_mask_and_ignore_logits(
+            self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs
+        )
+        losses = self._loss_fn(outputs, targets)
+        # Just for logging.
+        self._update_metrics(outputs, targets, "train")
+        loss_current = loss[:batch_size_current].mean()
+        loss_memory = loss[batch_size_current:].mean() if "memory" in batch else 0.0
+        self._loss_collections["train_losses"]["base_loss"](loss_current)
+        self._loss_collections["train_losses"]["memory_loss"](loss_memory)
+        # This is used for backprop.
+        loss = torch.topk(losses, self._effective_batch_size).values.mean()
+        return {"loss": loss}
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        super().on_save_checkpoint(checkpoint)
+        checkpoint["num_points_previous_tasks"] = self._num_points_previous_tasks
+
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        super().on_load_checkpoint(checkpoint)
+        self._num_points_previous_tasks = checkpoint["num_points_previous_tasks"]
+
+
+class SelectiveExperienceReplayModelUpdater(SingleTrainingLoopUpdater):
+    def __init__(
+        self,
+        model: RenateModule,
+        loss_fn: torch.nn.Module,
+        optimizer: Callable[[List[Parameter]], Optimizer],
+        memory_size: int,
+        batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC,
+        subsampling_ratio: float = 0.5,
+        learning_rate_scheduler: Optional[partial] = None,
+        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
+        batch_size: int = defaults.BATCH_SIZE,
+        input_state_folder: Optional[str] = None,
+        output_state_folder: Optional[str] = None,
+        max_epochs: int = defaults.MAX_EPOCHS,
+        train_transform: Optional[Callable] = None,
+        train_target_transform: Optional[Callable] = None,
+        test_transform: Optional[Callable] = None,
+        test_target_transform: Optional[Callable] = None,
+        buffer_transform: Optional[Callable] = None,
+        buffer_target_transform: Optional[Callable] = None,
+        metric: Optional[str] = None,
+        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
+        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
+        early_stopping_enabled: bool = False,
+        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
+        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
+        devices: Optional[int] = None,
+        strategy: str = defaults.DISTRIBUTED_STRATEGY,
+        precision: str = defaults.PRECISION,
+        seed: int = defaults.SEED,
+        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
+        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
+        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
+        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
+    ):
+        learner_kwargs = {
+            "memory_size": memory_size,
+            "batch_memory_frac": batch_memory_frac,
+            "subsampling_ratio": subsampling_ratio,
+            "batch_size": batch_size,
+            "seed": seed,
+        }
+        super().__init__(
+            model,
+            loss_fn=loss_fn,
+            optimizer=optimizer,
+            learner_class=SelectiveExperienceReplayLearner,
+            learner_kwargs=learner_kwargs,
+            input_state_folder=input_state_folder,
+            output_state_folder=output_state_folder,
+            max_epochs=max_epochs,
+            learning_rate_scheduler=learning_rate_scheduler,
+            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
+            train_transform=train_transform,
+            train_target_transform=train_target_transform,
+            test_transform=test_transform,
+            test_target_transform=test_target_transform,
+            buffer_transform=buffer_transform,
+            buffer_target_transform=buffer_target_transform,
+            metric=metric,
+            mode=mode,
+            logged_metrics=logged_metrics,
+            early_stopping_enabled=early_stopping_enabled,
+            logger=logger,
+            accelerator=accelerator,
+            devices=devices,
+            strategy=strategy,
+            precision=precision,
+            deterministic_trainer=deterministic_trainer,
+            gradient_clip_algorithm=gradient_clip_algorithm,
+            gradient_clip_val=gradient_clip_val,
+            mask_unused_classes=mask_unused_classes,
+        )

From f1a27004cd4eaf84d4969a84fd2b126aa6789501 Mon Sep 17 00:00:00 2001
From: Lukas Balles
Date: Tue, 19 Sep 2023 11:42:45 +0200
Subject: [PATCH 2/3] Update selective ER

---
 src/renate/updaters/experimental/selective_er.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/src/renate/updaters/experimental/selective_er.py b/src/renate/updaters/experimental/selective_er.py
index 4f53c8e0..2203fbfa 100644
--- a/src/renate/updaters/experimental/selective_er.py
+++ b/src/renate/updaters/experimental/selective_er.py
@@ -51,22 +51,14 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int)
         losses = self._loss_fn(outputs, targets)
         # Just for logging.
         self._update_metrics(outputs, targets, "train")
-        loss_current = loss[:batch_size_current].mean()
-        loss_memory = loss[batch_size_current:].mean() if "memory" in batch else 0.0
+        loss_current = losses[:batch_size_current].mean()
+        loss_memory = losses[batch_size_current:].mean() if "memory" in batch else 0.0
         self._loss_collections["train_losses"]["base_loss"](loss_current)
         self._loss_collections["train_losses"]["memory_loss"](loss_memory)
         # This is used for backprop.
-        loss = torch.topk(losses, self._effective_batch_size).values.mean()
+        loss = torch.topk(losses, min(len(losses), self._effective_batch_size)).values.mean()
         return {"loss": loss}
 
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        super().on_save_checkpoint(checkpoint)
-        checkpoint["num_points_previous_tasks"] = self._num_points_previous_tasks
-
-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        super().on_load_checkpoint(checkpoint)
-        self._num_points_previous_tasks = checkpoint["num_points_previous_tasks"]
-
 
 class SelectiveExperienceReplayModelUpdater(SingleTrainingLoopUpdater):
     def __init__(

From ac4ef667e89948d1f04085aaa526841bae59f0b3 Mon Sep 17 00:00:00 2001
From: Lukas Balles
Date: Tue, 19 Sep 2023 11:54:23 +0200
Subject: [PATCH 3/3] Add option for different subsampling strategies

---
 src/renate/cli/parsing_functions.py              | 14 ++++++++++++--
 src/renate/updaters/experimental/selective_er.py | 14 +++++++++++---
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py
index 6d784477..52474473 100644
--- a/src/renate/cli/parsing_functions.py
+++ b/src/renate/cli/parsing_functions.py
@@ -107,7 +107,12 @@ def get_updater_and_learner_kwargs(
         learner_args = learner_args + ["loss_weight_new_data", "memory_size", "batch_memory_frac"]
         updater_class = OfflineExperienceReplayModelUpdater
     elif args.updater == "Selective-ER":
-        learner_args = learner_args + ["subsampling_ratio", "memory_size", "batch_memory_frac"]
+        learner_args = learner_args + [
+            "subsampling_ratio",
+            "subsampling_strategy",
+            "memory_size",
+            "batch_memory_frac",
+        ]
         updater_class = SelectiveExperienceReplayModelUpdater
     elif args.updater == "RD":
         learner_args = learner_args + ["memory_size"]
         updater_class = RepeatedDistillationModelUpdater
@@ -539,7 +544,12 @@ def _add_selective_er_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
             "type": float,
             "default": None,
             "help": "TODO",
-            }
+            },
+            "subsampling_strategy": {
+                "type": str,
+                "default": None,
+                "help": "TODO",
+            },
         }
     )
 
diff --git a/src/renate/updaters/experimental/selective_er.py b/src/renate/updaters/experimental/selective_er.py
index 2203fbfa..920f5c80 100644
--- a/src/renate/updaters/experimental/selective_er.py
+++ b/src/renate/updaters/experimental/selective_er.py
@@ -1,7 +1,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 import torchmetrics
@@ -26,8 +26,11 @@ class SelectiveExperienceReplayLearner(OfflineExperienceReplayLearner):
     TODO
     """
 
-    def __init__(self, subsampling_ratio: float = 0.5, **kwargs) -> None:
+    def __init__(
+        self, subsampling_ratio: float = 0.5, subsampling_strategy="loss_topk", **kwargs
+    ) -> None:
         super().__init__(**kwargs)
+        self._subsampling_strategy = subsampling_strategy
         self._effective_batch_size = round(subsampling_ratio * self._batch_size)
         if not self._effective_batch_size > 0:
             raise ValueError(
@@ -56,7 +59,10 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int)
         self._loss_collections["train_losses"]["base_loss"](loss_current)
         self._loss_collections["train_losses"]["memory_loss"](loss_memory)
         # This is used for backprop.
-        loss = torch.topk(losses, min(len(losses), self._effective_batch_size)).values.mean()
+        if self._subsampling_strategy == "loss_topk":
+            loss = torch.topk(losses, min(len(losses), self._effective_batch_size)).values.mean()
+        else:
+            raise ValueError(f"Unknown strategy: {self._subsampling_strategy}.")
         return {"loss": loss}
 
 
@@ -69,6 +75,7 @@ def __init__(
         memory_size: int,
         batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC,
         subsampling_ratio: float = 0.5,
+        subsampling_strategy: str = "loss_topk",
         learning_rate_scheduler: Optional[partial] = None,
         learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
         batch_size: int = defaults.BATCH_SIZE,
@@ -100,6 +107,7 @@ def __init__(
             "memory_size": memory_size,
             "batch_memory_frac": batch_memory_frac,
             "subsampling_ratio": subsampling_ratio,
+            "subsampling_strategy": subsampling_strategy,
             "batch_size": batch_size,
            "seed": seed,
         }
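
Usage sketch (illustrative only, not part of the patch series): the snippet below shows how the new SelectiveExperienceReplayModelUpdater could be driven once these patches are applied. The constructor arguments mirror the signature added in selective_er.py; the model, dataset, optimizer settings, and hyperparameter values are placeholders, and the final update() call is assumed to follow Renate's generic ModelUpdater interface.

# Illustrative sketch only -- not part of the patch. Placeholders are marked below.
from functools import partial

import torch

from renate.updaters.experimental.selective_er import SelectiveExperienceReplayModelUpdater

model = ...          # placeholder: any RenateModule
train_dataset = ...  # placeholder: torch Dataset holding the current data chunk

# training_step slices and top-k's a per-sample loss vector, so the loss must not reduce.
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
optimizer = partial(torch.optim.SGD, lr=0.05, momentum=0.9)

updater = SelectiveExperienceReplayModelUpdater(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    memory_size=500,                   # capacity of the replay buffer
    batch_memory_frac=0.5,             # fraction of each batch drawn from memory
    subsampling_ratio=0.5,             # backprop only the top 50% highest-loss examples
    subsampling_strategy="loss_topk",  # the only strategy implemented in this series
    batch_size=64,
    max_epochs=10,
)
updater.update(train_dataset, task_id="task_1")  # assumed generic ModelUpdater call

With subsampling_ratio=0.5, each step computes the forward pass and per-sample losses on the full current-plus-memory batch, but averages only the effective_batch_size largest losses for the backward pass, which is the selective-backprop behavior this series adds.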