diff --git a/docs/source-pytorch/advanced/training_tricks.rst b/docs/source-pytorch/advanced/training_tricks.rst index 25dd996c628a4..23c1c67b8734e 100644 --- a/docs/source-pytorch/advanced/training_tricks.rst +++ b/docs/source-pytorch/advanced/training_tricks.rst @@ -50,23 +50,44 @@ Read more about :ref:`Configuring Gradient Clipping `__ by the PyTorch team. +Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging` +is a generic callback that wraps the +`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from +PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used and it can be customized to run at specific steps +or epochs. -.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback +The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA +procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant +learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__) when the +procedure starts. + +.. seealso:: + For a more detailed explanation of SWA and how it works, read + `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`__ by the PyTorch team. ..
testcode:: - # Enable Stochastic Weight Averaging using the callback - trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)]) + from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging + from torch.optim.swa_utils import get_ema_avg_fn + + # Enable Exponential Moving Average after 100 steps + class EMAWeightAveraging(WeightAveraging): + def __init__(self): + super().__init__(avg_fn=get_ema_avg_fn()) + def should_update(self, step_idx=None, epoch_idx=None): + return (step_idx is not None) and (step_idx >= 100) + trainer = Trainer(callbacks=EMAWeightAveraging()) + + # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01 + trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01)) ---------- diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 1f58f6ac23dd5..278cc98ef5547 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -48,6 +48,7 @@ callbacks ThroughputMonitor Timer TQDMProgressBar + WeightAveraging cli ----- diff --git a/docs/source-pytorch/extensions/callbacks.rst b/docs/source-pytorch/extensions/callbacks.rst index c2a621f8b6d7b..7ed285591c4dc 100644 --- a/docs/source-pytorch/extensions/callbacks.rst +++ b/docs/source-pytorch/extensions/callbacks.rst @@ -83,6 +83,7 @@ Lightning has a few built-in callbacks. 
StochasticWeightAveraging Timer TQDMProgressBar + WeightAveraging ---------- diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst index 6b5e4b12b307f..333ef9834ef84 100644 --- a/docs/source-pytorch/glossary/index.rst +++ b/docs/source-pytorch/glossary/index.rst @@ -42,13 +42,13 @@ Strategy registry <../advanced/strategy_registry> Strategy integrations <../integrations/strategies/index> Style guide <../starter/style_guide> - SWA <../advanced/training_tricks> SLURM <../clouds/cluster_advanced> Tensor Parallel <../advanced/model_parallel/tp> Transfer learning <../advanced/transfer_learning> Trainer <../common/trainer> TorchRun (TorchElastic) <../clouds/cluster_intermediate_2> Warnings <../advanced/warnings> + Weight averaging <../advanced/training_tricks> ######## @@ -326,13 +326,6 @@ Glossary :button_link: ../starter/style_guide.html :height: 100 -.. displayitem:: - :header: SWA - :description: Stochastic Weight Averaging (SWA) can make your models generalize better - :col_css: col-md-12 - :button_link: ../advanced/training_tricks.html#stochastic-weight-averaging - :height: 100 - .. displayitem:: :header: SLURM :description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters @@ -375,6 +368,13 @@ Glossary :button_link: ../advanced/warnings.html :height: 100 +.. displayitem:: + :header: Weight averaging + :description: Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) can make your models generalize better + :col_css: col-md-12 + :button_link: ../advanced/training_tricks.html#weight-averaging + :height: 100 + .. 
raw:: html diff --git a/docs/source-pytorch/model/build_model_intermediate.rst b/docs/source-pytorch/model/build_model_intermediate.rst index 82362af7ecc83..8a56d20947334 100644 --- a/docs/source-pytorch/model/build_model_intermediate.rst +++ b/docs/source-pytorch/model/build_model_intermediate.rst @@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni ) # access the latest state of the art techniques - trainer = Trainer(callbacks=[StochasticWeightAveraging(...)]) + trainer = Trainer(callbacks=[WeightAveraging(...)]) ---- diff --git a/docs/source-pytorch/starter/introduction.rst b/docs/source-pytorch/starter/introduction.rst index 8e55afb907aab..ecdda6ac1c53f 100644 --- a/docs/source-pytorch/starter/introduction.rst +++ b/docs/source-pytorch/starter/introduction.rst @@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. These are state-of-th ) # access the latest state of the art techniques - trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)]) + trainer = L.Trainer(callbacks=[WeightAveraging(...)]) ---- diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index e3b847490a4c6..fa086dd798210 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- WeightAveraging callback that wraps the PyTorch AveragedModel class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545)) - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593)) diff --git a/src/lightning/pytorch/callbacks/__init__.py b/src/lightning/pytorch/callbacks/__init__.py index 9ee34f3866b27..d0ffb7b6a990c 100644 --- a/src/lightning/pytorch/callbacks/__init__.py +++ b/src/lightning/pytorch/callbacks/__init__.py @@ -32,6 +32,7 @@ from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor from lightning.pytorch.callbacks.timer import Timer +from lightning.pytorch.callbacks.weight_averaging import WeightAveraging __all__ = [ "BackboneFinetuning", @@ -58,4 +59,5 @@ "ThroughputMonitor", "Timer", "TQDMProgressBar", + "WeightAveraging", ] diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 375bd15f29051..79c5423c54084 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -65,7 +65,7 @@ def __init__( .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch. - See also how to :ref:`enable it directly on the Trainer ` + See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Weight Averaging>`. Arguments: diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py new file mode 100644 index 0000000000000..a983b32a1a161 --- /dev/null +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -0,0 +1,361 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^
"""

import itertools
from copy import deepcopy
from typing import Any, Optional, Union

import torch
from torch.optim.swa_utils import AveragedModel
from typing_extensions import override

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT


class WeightAveraging(Callback):
    r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
    (EMA) after each training step.

    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of
    differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set
    to ``None``, the device will be inferred from the original model. By default, the callback will compute running
    averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only
    the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using
    ``torch.optim.swa_utils.update_bn()``).

    You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
    :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the
    equally-weighted average of the weights (SWA).

    You can customize when the average model is updated by overriding the ``should_update()`` method. The callback calls
    it with either ``step_idx`` or ``epoch_idx`` and the method returns a boolean indicating whether to update after the
    given step or epoch. The default is to update after every step.

    During validation and after the training finishes, the current model parameters will be replaced with the averaged
    values.

    See also the documentation on the :ref:`weight averaging callbacks <advanced/training_tricks:Weight Averaging>`
    provided by Lightning.

    Note:
        To ensure that the :class:`AveragedModel` will contain all layers, ``setup()`` will call
        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
        :class:`AveragedModel`. However, that hook is not called in a strategy aware context, sharded models do not work
        with weight averaging, and a warning will be issued.

    Example::

        from lightning.pytorch.callbacks import WeightAveraging
        from torch.optim.swa_utils import get_ema_avg_fn

        class EMAWeightAveraging(WeightAveraging):
            def __init__(self):
                super().__init__(avg_fn=get_ema_avg_fn())

            def should_update(self, step_idx=None, epoch_idx=None):
                # Start after 100 steps.
                return (step_idx is not None) and (step_idx >= 100)

        trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
        trainer.fit(model, dataloader)

    Args:
        device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be
            inferred from the original model.
        use_buffers: If ``False``, the buffers of the model will not be averaged.
        kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as ``avg_fn``
            or ``multi_avg_fn``.

    """

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = "cpu",
        use_buffers: bool = True,
        **kwargs: Any,
    ) -> None:
        # The default value is a string so that jsonargparse knows how to serialize it.
        if isinstance(device, str):
            self._device: Optional[Union[torch.device, int]] = torch.device(device)
        else:
            self._device = device
        self._use_buffers = use_buffers
        self._kwargs = kwargs

        self._average_model: Optional[AveragedModel] = None

        # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
        # that self.should_update() will be first called after the first optimizer step, which takes place after N
        # batches when using accumulate_grad_batches=N.
        self._latest_update_step = 0
        # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
        # negative value means that if self.should_update(epoch_idx=0) returns True, the first update is after the first
        # epoch.
        self._latest_update_epoch = -1

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        """Called after every optimizer step and after every training epoch to check whether the average model should
        be updated.

        One of the arguments is set to the zero-based index of the last training step or epoch. The default
        implementation returns ``True`` when any ``step_idx`` is provided. The user can customize when the average model
        gets updated by overriding this method.

        Args:
            step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
            epoch_idx: Index of the last epoch, or ``None`` when called after an optimizer step.

        Returns:
            ``True`` if the average model should be updated and ``False`` if not.

        """
        return step_idx is not None

    @override
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune begins.

        Creates an :class:`AveragedModel` when fit begins.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            stage: The :class:`~lightning.pytorch.trainer.trainer.Trainer` state.

        """
        if stage == "fit":
            # Compare against None explicitly: an integer device index of 0 is falsy, so a truthiness test would
            # silently discard an explicitly requested device 0 and fall back to the module's device.
            device = self._device if self._device is not None else pl_module.device

            # If the configure_model hook is overridden, call it to create the layers before constructing the
            # AveragedModel. However, sharding will not be done and a warning will be issued.
            if is_overridden("configure_model", pl_module):
                rank_zero_warn(
                    "You're using the WeightAveraging callback with a model that overrides the configure_model "
                    "callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
                )
                pl_module.configure_model()

            self._average_model = AveragedModel(
                model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs
            )

    @override
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when a training batch ends.

        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            outputs: Outputs from the training batch.
            batch: The training batch.
            batch_idx: Index of the training batch.

        """
        # trainer.global_step is the number of optimizer steps taken so far, i.e. 1 after the first optimizer step. To
        # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
        step_idx = trainer.global_step - 1
        # The first condition skips batches that did not advance the optimizer step counter since the last update.
        if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_step = trainer.global_step

    @override
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a training epoch ends.

        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        # The first condition makes sure that the same epoch is never averaged twice.
        if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_epoch = trainer.current_epoch

    @override
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when training ends.

        Transfers parameters from the :class:`AveragedModel` to the current model.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        assert self._average_model is not None
        self._copy_average_to_current(pl_module)

    @override
    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a validation epoch begins.

        Transfers parameter values from the :class:`AveragedModel` to the current model.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        # Nothing to swap when no average model has been created (it is only created for the "fit" stage).
        if self._average_model is not None:
            self._swap_models(pl_module)

    @override
    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when a validation epoch ends.

        Recovers the current model parameters from the :class:`AveragedModel`.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        # Swapping a second time undoes the swap made when the validation epoch started.
        if self._average_model is not None:
            self._swap_models(pl_module)

    @override
    def state_dict(self) -> dict[str, Any]:
        """Called when saving a checkpoint.

        Creates a ``state_dict`` of the callback state. Both update counters are stored, so that neither the
        step-based nor the epoch-based guard against repeated updates is reset when training is resumed.

        Returns:
            A dictionary containing the callback state.

        """
        return {
            "latest_update_step": self._latest_update_step,
            "latest_update_epoch": self._latest_update_epoch,
        }

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint.

        Reloads the callback state given a ``state_dict``.

        Args:
            state_dict: A dictionary containing the callback state.

        """
        self._latest_update_step = state_dict["latest_update_step"]
        # State dictionaries that only contain "latest_update_step" are still accepted; in that case the epoch
        # counter keeps its current value.
        self._latest_update_epoch = state_dict.get("latest_update_epoch", self._latest_update_epoch)

    @override
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        r"""Called when saving a checkpoint.

        Moves the current model state to the key ``current_model_state``, and places the average model state in
        ``state_dict`` instead. Any other state variables of the ``AveragedModel`` will be saved in
        ``averaging_state``.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            checkpoint: The checkpoint dictionary that will be saved.

        """
        if self._average_model is None:
            rank_zero_info(
                "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state "
                "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the "
                "average model parameters will be saved to the state_dict in the checkpoint."
            )
        else:
            # Entries whose name starts with this prefix are the weights of the wrapped module. They are stored under
            # "state_dict" with the prefix removed; every other entry goes to "averaging_state".
            prefix = "module."
            average_model_state = self._average_model.state_dict()
            checkpoint["current_model_state"] = checkpoint["state_dict"]
            checkpoint["state_dict"] = {
                name[len(prefix) :]: value for name, value in average_model_state.items() if name.startswith(prefix)
            }
            checkpoint["averaging_state"] = {
                name: value for name, value in average_model_state.items() if not name.startswith(prefix)
            }

    @override
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        r"""Called when loading a model checkpoint.

        Loads the current model and the :class:`AveragedModel` parameters from the checkpoint.

        Args:
            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
            checkpoint: The full checkpoint dictionary that got loaded by the Trainer.

        """
        if self._average_model is None:
            rank_zero_warn(
                "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The "
                "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, "
                "you can ignore this warning. To disable the warning, remove the WeightAveraging callback."
            )
        elif ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
            rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
            # Rebuild the AveragedModel state: re-add the "module." prefix to the averaged weights and merge in the
            # remaining averaging state.
            average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
            average_model_state |= checkpoint["averaging_state"]
            self._average_model.load_state_dict(average_model_state)
            # The current model state has already been loaded from "state_dict" (which contains the average model
            # weights) at this point, so overwriting "state_dict" in the checkpoint dictionary makes no difference. We
            # have to reload the model state from "current_model_state".
            pl_module.load_state_dict(checkpoint["current_model_state"])
        else:
            rank_zero_warn(
                "The checkpoint was not created with WeightAveraging. Both the current and the average model will be "
                "initialized with state_dict."
            )
            # The state is deep-copied before it is handed to the average model's module.
            self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False)

    def _swap_models(self, pl_module: "pl.LightningModule") -> None:
        """Swaps the parameter values of the current model and the :class:`AveragedModel`.

        Both the parameters and the buffers are exchanged, pairwise and in place.

        Args:
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        assert self._average_model is not None
        average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            # Keep a copy of the averaged values, because the next line overwrites them.
            tmp = average_param.data.clone()
            average_param.data.copy_(current_param.data)
            current_param.data.copy_(tmp)

    def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
        """Copies the parameter values from the :class:`AveragedModel` to the current model.

        Both the parameters and the buffers are copied, pairwise and in place.

        Args:
            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

        """
        assert self._average_model is not None
        average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            current_param.data.copy_(average_param.data)


# diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
# new file mode 100644
# index 0000000000000..ec230b2fd6c97
# --- /dev/null
# +++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
# @@ -0,0 +1,331 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional

import pytest
import torch
from torch import Tensor, nn
from torch.optim.swa_utils import get_swa_avg_fn
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf


class TestModel(BoringModel):
    """A small model, optionally with batch normalization, that can be told to crash from a given epoch on."""

    def __init__(self, batch_norm: bool = True) -> None:
        super().__init__()
        layers = [nn.Linear(32, 32)]
        if batch_norm:
            layers.append(nn.BatchNorm1d(32))
        layers += [nn.ReLU(), nn.Linear(32, 2)]
        self.layer = nn.Sequential(*layers)
        # Epoch from which training_step() raises, or None to never crash.
        self.crash_on_epoch: Optional[int] = None

    def training_step(self, batch: Tensor, batch_idx: int) -> Any:
        # Compare against None explicitly, so that crash_on_epoch=0 crashes in the first epoch instead of being
        # treated as "never crash".
        if self.crash_on_epoch is not None and self.trainer.current_epoch >= self.crash_on_epoch:
            raise Exception("CRASH")
        return super().training_step(batch, batch_idx)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


class LargeTestModel(BoringModel):
    """A model that creates its layers in ``configure_model()`` instead of the constructor."""

    def __init__(self):
        super().__init__()
        self.layer = None

    def configure_model(self):
        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


class EMAAveragingFunction:
    """EMA averaging function.

    Functionally equivalent to the closure that ``get_ema_avg_fn()`` would return. This class is needed because we
    cannot use a closure with ddp_spawn. (``Popen(process_obj)`` would fail with
    ``Can't get local object 'get_ema_avg_fn.<locals>.ema_update'``).

    """

    def __init__(self, decay: float = 0.999) -> None:
        self.decay = decay

    @torch.no_grad()
    def __call__(self, ema_param: Tensor, current_param: Tensor, num_averaged: Tensor) -> Tensor:
        return self.decay * ema_param + (1 - self.decay) * current_param


class EMATestCallback(WeightAveraging):
    """EMA callback that counts swap and copy calls and checks the counts in the training hooks."""

    def __init__(self, devices: int = 1, **kwargs: Any) -> None:
        super().__init__(avg_fn=EMAAveragingFunction(), **kwargs)
        self.devices = devices
        self.swap_calls = 0
        self.copy_calls = 0
        # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0.
        self.first_epoch: Optional[int] = None

    def _swap_models(self, *args: Any, **kwargs: Any):
        self.swap_calls += 1
        return super()._swap_models(*args, **kwargs)

    def _copy_average_to_current(self, *args: Any, **kwargs: Any):
        self.copy_calls += 1
        return super()._copy_average_to_current(*args, **kwargs)

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_train_start(trainer, pl_module)
        assert self.swap_calls == 0
        assert self.copy_calls == 0

    def on_train_epoch_start(self, trainer: Trainer, *args: Any) -> None:
        super().on_train_epoch_start(trainer, *args)
        # Since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will not update the
        # model and will just call the epoch-level hooks. For that reason, we check that we are not restarting before
        # choosing the first epoch.
        if self.first_epoch is None and not trainer.fit_loop.restarting:
            self.first_epoch = trainer.current_epoch

    def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None:
        super().on_train_epoch_end(trainer, *args)
        assert self._average_model.n_averaged == trainer.global_step
        assert self.swap_calls == (trainer.current_epoch + 1 - self.first_epoch) * 2
        assert self.copy_calls == 0

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_train_end(trainer, pl_module)
        # length=32, batch_size=4, accumulate_grad_batches=2
        # => Using one process we have 4 optimizer steps per epoch.
        # => Using two processes we have 2 optimizer steps per epoch.
        steps_per_epoch = 4 // self.devices
        assert self._average_model.n_averaged == trainer.max_epochs * steps_per_epoch
        assert self.swap_calls == (trainer.max_epochs - self.first_epoch) * 2
        assert self.copy_calls == 1


class SWATestCallback(WeightAveraging):
    """SWA callback that averages after epochs 3, 5, and 7 and checks the bookkeeping in the training hooks."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(avg_fn=get_swa_avg_fn(), **kwargs)
        self.swap_calls = 0
        self.copy_calls = 0
        # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0.
        self.first_epoch: Optional[int] = None

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        return epoch_idx in (3, 5, 7)

    def _swap_models(self, *args: Any, **kwargs: Any):
        self.swap_calls += 1
        return super()._swap_models(*args, **kwargs)

    def _copy_average_to_current(self, *args: Any, **kwargs: Any):
        self.copy_calls += 1
        return super()._copy_average_to_current(*args, **kwargs)

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_train_start(trainer, pl_module)
        assert self.swap_calls == 0
        assert self.copy_calls == 0

    def on_train_epoch_start(self, trainer: Trainer, *args: Any) -> None:
        super().on_train_epoch_start(trainer, *args)
        # Since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will not update the
        # model and will just call the epoch-level hooks. For that reason, we check that we are not restarting before
        # choosing the first epoch.
        if self.first_epoch is None and not trainer.fit_loop.restarting:
            self.first_epoch = trainer.current_epoch

    def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None:
        super().on_train_epoch_end(trainer, *args)
        # One more update is expected after each of the epochs 3, 5, and 7.
        if trainer.current_epoch < 3:
            assert self._average_model.n_averaged == 0
        elif trainer.current_epoch < 5:
            assert self._average_model.n_averaged == 1
        elif trainer.current_epoch < 7:
            assert self._average_model.n_averaged == 2
        else:
            assert self._average_model.n_averaged == 3
        assert self.swap_calls == (trainer.current_epoch + 1 - self.first_epoch) * 2
        assert self.copy_calls == 0

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_train_end(trainer, pl_module)
        assert self._average_model.n_averaged == 3
        assert self.swap_calls == (trainer.max_epochs - self.first_epoch) * 2
        assert self.copy_calls == 1


def test_weight_averaging_deepcopy(tmp_path):
    """Ensure that WeightAveraging callback doesn't deepcopy the data loaders or the data module and consume memory
    more than necessary."""

    class TestCallback(WeightAveraging):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.setup_called = False

        def setup(self, trainer, pl_module, stage) -> None:
            super().setup(trainer, pl_module, stage)
            assert self._average_model.module.train_dataloader is not pl_module.train_dataloader
            assert self._average_model.module.train_dataloader.__self__ == self._average_model.module
            assert self._average_model.module._trainer is None
            self.setup_called = True

    callback = TestCallback()
    trainer = Trainer(default_root_dir=tmp_path, callbacks=callback, fast_dev_run=True)
    trainer.fit(BoringModel(), train_dataloaders=DataLoader(RandomDataset(32, 2)))
    assert callback.setup_called


@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("iterable_dataset", [True, False])
def test_ema(tmp_path, batch_norm: bool, iterable_dataset: bool):
    """Train with EMA on the CPU, with and without batch normalization, on map-style and iterable datasets."""
    model = TestModel(batch_norm=batch_norm)
    dataset = RandomIterableDataset(32, 32) if iterable_dataset else RandomDataset(32, 32)
    _train(model, dataset, tmp_path, EMATestCallback())


@pytest.mark.parametrize(
    "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))]
)
def test_ema_accelerator(tmp_path, accelerator):
    """Train with EMA on a single accelerator device."""
    model = TestModel()
    dataset = RandomDataset(32, 32)
    _train(model, dataset, tmp_path, EMATestCallback(), accelerator=accelerator, devices=1)


@RunIf(min_cuda_gpus=2, standalone=True)
def test_ema_ddp(tmp_path):
    """Train with EMA using the ``ddp`` strategy on two GPUs."""
    model = TestModel()
    dataset = RandomDataset(32, 32)
    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2)


@RunIf(min_cuda_gpus=2)
def test_ema_ddp_spawn(tmp_path):
    """Train with EMA using the ``ddp_spawn`` strategy on two GPUs."""
    model = TestModel()
    dataset = RandomDataset(32, 32)
    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2)


@RunIf(skip_windows=True)
def test_ema_ddp_spawn_cpu(tmp_path):
    """Train with EMA using the ``ddp_spawn`` strategy with two CPU processes."""
    model = TestModel()
    dataset = RandomDataset(32, 32)
    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2)


@pytest.mark.parametrize("crash_on_epoch", [1, 3, 5])
def test_ema_resume(tmp_path, crash_on_epoch):
    """Check that crashing and resuming from the checkpoint gives the same parameters as an uninterrupted run."""
    dataset = RandomDataset(32, 32)
    model1 = TestModel()
    model2 = deepcopy(model1)

    _train(model1, dataset, tmp_path, EMATestCallback())

    model2.crash_on_epoch = crash_on_epoch
    model2 = _train_and_resume(model2, dataset, tmp_path)

    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        assert torch.allclose(param1, param2)


@RunIf(skip_windows=True)
def test_ema_resume_ddp(tmp_path):
    """Crash and resume an EMA run that uses the ``ddp_spawn`` strategy with two processes."""
    model = TestModel()
    model.crash_on_epoch = 3
    dataset = RandomDataset(32, 32)
    _train_and_resume(model, dataset, tmp_path, strategy="ddp_spawn", devices=2)


def test_swa(tmp_path):
    """Train with the epoch-based SWA callback on the CPU."""
    model = TestModel()
    dataset = RandomDataset(32, 32)
    _train(model, dataset, tmp_path, SWATestCallback())


@pytest.mark.parametrize(
    ("strategy", "accelerator", "devices"),
    [
        ("auto", "cpu", 1),
        pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
        pytest.param("fsdp", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
    ],
)
def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
    """Check that the average model contains the layers that the model creates in ``configure_model()``."""
    model = LargeTestModel()
    dataset = RandomDataset(32, 32)
    callback = EMATestCallback()
    _train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)
    assert isinstance(callback._average_model.module.layer, nn.Sequential)


def _train(
    model: BoringModel,
    dataset: Dataset,
    tmp_path: str,
    callback: WeightAveraging,
    strategy: str = "auto",
    accelerator: str = "cpu",
    devices: int = 1,
    checkpoint_path: Optional[str] = None,
    will_crash: bool = False,
) -> None:
    """Fits ``model`` on ``dataset`` for 8 epochs with the given weight averaging callback.

    Args:
        model: The model to train.
        dataset: The training data; it is loaded with ``batch_size=4`` and without shuffling.
        tmp_path: Root directory for the trainer.
        callback: The weight averaging callback under test.
        strategy: Passed to the ``Trainer``.
        accelerator: Passed to the ``Trainer``; deterministic mode is only requested for ``"cpu"``.
        devices: Passed to the ``Trainer``.
        checkpoint_path: Passed to ``fit()`` as ``ckpt_path``.
        will_crash: If ``True``, checkpointing is enabled and ``fit()`` is expected to raise ``"CRASH"``.

    """
    deterministic = accelerator == "cpu"
    trainer = Trainer(
        accelerator=accelerator,
        strategy=strategy,
        devices=devices,
        logger=False,
        callbacks=callback,
        max_epochs=8,
        num_sanity_val_steps=0,
        enable_checkpointing=will_crash,
        enable_progress_bar=False,
        enable_model_summary=False,
        accumulate_grad_batches=2,
        deterministic=deterministic,
        default_root_dir=tmp_path,
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    if will_crash:
        with pytest.raises(Exception, match="CRASH"):
            trainer.fit(model, dataloader, ckpt_path=checkpoint_path)
    else:
        trainer.fit(model, dataloader, ckpt_path=checkpoint_path)
    assert trainer.lightning_module == model


def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices: int = 1, **kwargs) -> TestModel:
    """Trains ``model`` until it crashes, then resumes from the single saved checkpoint with a fresh model.

    Args:
        model: A model whose ``crash_on_epoch`` has been set, so that the first run crashes.
        dataset: The training data.
        tmp_path: Root directory for the trainer; the checkpoint is expected in its ``checkpoints`` subdirectory.
        devices: Number of devices, passed to both the trainer and the test callback.
        kwargs: Additional keyword arguments passed to ``_train()``.

    Returns:
        The model that was loaded from the checkpoint and trained to the end.

    """
    _train(model, dataset, tmp_path, EMATestCallback(devices=devices), devices=devices, will_crash=True, **kwargs)

    checkpoint_dir = Path(tmp_path) / "checkpoints"
    checkpoint_names = os.listdir(checkpoint_dir)
    assert len(checkpoint_names) == 1
    checkpoint_path = str(checkpoint_dir / checkpoint_names[0])

    model = TestModel.load_from_checkpoint(checkpoint_path)
    callback = EMATestCallback(devices=devices)
    _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
    return model