From a4a64af4c07132aa28092ffc1b6c5a6edeaf8fcb Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 10 Jul 2023 10:42:49 +0300 Subject: [PATCH] Revision of the MoCo SSL model (#928) --- CHANGELOG.md | 2 +- docs/source/introduction_guide.rst | 2 +- docs/source/models/self_supervised.rst | 2 +- pyproject.toml | 1 - .../detection/retinanet/retinanet_module.py | 5 +- .../models/detection/yolo/yolo_module.py | 3 +- .../models/self_supervised/__init__.py | 4 +- .../models/self_supervised/moco/callbacks.py | 2 +- .../self_supervised/moco/moco2_module.py | 399 ------------------ .../self_supervised/moco/moco_module.py | 328 ++++++++++++++ .../models/self_supervised/moco/utils.py | 163 +++++++ tests/models/self_supervised/test_models.py | 8 +- tests/models/test_scripts.py | 4 +- 13 files changed, 506 insertions(+), 417 deletions(-) delete mode 100644 src/pl_bolts/models/self_supervised/moco/moco2_module.py create mode 100644 src/pl_bolts/models/self_supervised/moco/moco_module.py create mode 100644 src/pl_bolts/models/self_supervised/moco/utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d00d9d990..2819df9d0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Revision of the MoCo SSL model ([#928](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/928)) ### Deprecated diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst index df18916b93..18d74f0d3a 100644 --- a/docs/source/introduction_guide.rst +++ b/docs/source/introduction_guide.rst @@ -22,7 +22,7 @@ All models are tested (daily), benchmarked, documented and work on CPUs, TPUs, G from pl_bolts.models import VAE from pl_bolts.models.vision import GPT2, ImageGPT, PixelCNN - from pl_bolts.models.self_supervised import AMDIM, CPC_v2, SimCLR, Moco_v2 + from pl_bolts.models.self_supervised import AMDIM, CPC_v2, SimCLR, MoCo from pl_bolts.models import LinearRegression, LogisticRegression from pl_bolts.models.gans import GAN from pl_bolts.callbacks import PrintTableMetricsCallback diff --git a/docs/source/models/self_supervised.rst b/docs/source/models/self_supervised.rst index 270a6f06e7..dacb523ec9 100644 --- a/docs/source/models/self_supervised.rst +++ b/docs/source/models/self_supervised.rst @@ -259,7 +259,7 @@ CPC (v2) API Moco (v2) API ^^^^^^^^^^^^^ -.. autoclass:: pl_bolts.models.self_supervised.Moco_v2 +.. 
autoclass:: pl_bolts.models.self_supervised.MoCo :noindex: SimCLR diff --git a/pyproject.toml b/pyproject.toml index 2f34e45b2a..adca3a93b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,7 +198,6 @@ module = [ "pl_bolts.models.self_supervised.cpc.transforms", "pl_bolts.models.self_supervised.evaluator", "pl_bolts.models.self_supervised.moco.callbacks", - "pl_bolts.models.self_supervised.moco.moco2_module", "pl_bolts.models.self_supervised.moco.transforms", "pl_bolts.models.self_supervised.resnets", "pl_bolts.models.self_supervised.simclr.simclr_finetuner", diff --git a/src/pl_bolts/models/detection/retinanet/retinanet_module.py b/src/pl_bolts/models/detection/retinanet/retinanet_module.py index c0415cc30b..67583fb818 100644 --- a/src/pl_bolts/models/detection/retinanet/retinanet_module.py +++ b/src/pl_bolts/models/detection/retinanet/retinanet_module.py @@ -135,10 +135,7 @@ def configure_optimizers(self): @under_review() def cli_main(): - try: # Backward compatibility for Lightning CLI - from pytorch_lightning.cli import LightningCLI # PL v1.9+ - except ImportError: - from pytorch_lightning.utilities.cli import LightningCLI # PL v1.8 + from pytorch_lightning.cli import LightningCLI from pl_bolts.datamodules import VOCDetectionDataModule diff --git a/src/pl_bolts/models/detection/yolo/yolo_module.py b/src/pl_bolts/models/detection/yolo/yolo_module.py index 342f9f5ca7..cd29409163 100644 --- a/src/pl_bolts/models/detection/yolo/yolo_module.py +++ b/src/pl_bolts/models/detection/yolo/yolo_module.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn from pytorch_lightning import LightningModule -from pytorch_lightning.utilities.cli import LightningCLI from pytorch_lightning.utilities.types import STEP_OUTPUT from torch import Tensor, optim @@ -614,4 +613,6 @@ def _resize(self, image: Tensor, target: TARGET) -> Tuple[Tensor, TARGET]: if __name__ == "__main__": + from pytorch_lightning.cli import LightningCLI + LightningCLI(CLIYOLO, ResizedVOCDetectionDataModule, seed_everything_default=42) diff --git a/src/pl_bolts/models/self_supervised/__init__.py b/src/pl_bolts/models/self_supervised/__init__.py index c2518bb303..7286fe9dc8 100644 --- a/src/pl_bolts/models/self_supervised/__init__.py +++ b/src/pl_bolts/models/self_supervised/__init__.py @@ -20,7 +20,7 @@ from pl_bolts.models.self_supervised.byol.byol_module import BYOL from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2 from pl_bolts.models.self_supervised.evaluator import SSLEvaluator -from pl_bolts.models.self_supervised.moco.moco2_module import Moco_v2 +from pl_bolts.models.self_supervised.moco.moco_module import MoCo from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner @@ -31,7 +31,7 @@ "BYOL", "CPC_v2", "SSLEvaluator", - "Moco_v2", + "MoCo", "SimCLR", "SimSiam", "SSLFineTuner", diff --git a/src/pl_bolts/models/self_supervised/moco/callbacks.py b/src/pl_bolts/models/self_supervised/moco/callbacks.py index dfc4bb3948..24bd5a1bf6 100644 --- a/src/pl_bolts/models/self_supervised/moco/callbacks.py +++ b/src/pl_bolts/models/self_supervised/moco/callbacks.py @@ -6,7 +6,7 @@ @under_review() -class MocoLRScheduler(Callback): +class MoCoLRScheduler(Callback): def __init__(self, initial_lr=0.03, use_cosine_scheduler=False, schedule=(120, 160), max_epochs=200) -> None: super().__init__() self.lr = initial_lr diff --git 
a/src/pl_bolts/models/self_supervised/moco/moco2_module.py b/src/pl_bolts/models/self_supervised/moco/moco2_module.py deleted file mode 100644 index 2c534614b3..0000000000 --- a/src/pl_bolts/models/self_supervised/moco/moco2_module.py +++ /dev/null @@ -1,399 +0,0 @@ -# Original work is: Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. -# This implementation is: Copyright (c) PyTorch Lightning, Inc. and its affiliates. All Rights Reserved -# -# This implementation is licensed under Attribution-NonCommercial 4.0 International; -# You may not use this file except in compliance with the License. -# -# You may obtain a copy of the License from the LICENSE file present in this folder. -"""MoCo2. - -Adapted from https: //github.com/facebookresearch/moco. -""" -from argparse import ArgumentParser -from typing import Union - -import torch -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.strategies import DDPStrategy -from torch import nn -from torch.nn import functional as F # noqa: N812 - -from pl_bolts.metrics import mean, precision_at_k -from pl_bolts.transforms.self_supervised.moco_transforms import ( - MoCo2EvalCIFAR10Transforms, - MoCo2EvalImagenetTransforms, - MoCo2EvalSTL10Transforms, - MoCo2TrainCIFAR10Transforms, - MoCo2TrainImagenetTransforms, - MoCo2TrainSTL10Transforms, -) -from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import under_review -from pl_bolts.utils.warnings import warn_missing_pkg - -if _TORCHVISION_AVAILABLE: - import torchvision -else: # pragma: no cover - warn_missing_pkg("torchvision") - - -@under_review() -class Moco_v2(LightningModule): # noqa: N801 - """PyTorch Lightning implementation of `Moco `_ - - Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He. 
- - Code adapted from `facebookresearch/moco `_ to Lightning by: - - - `William Falcon `_ - - Example:: - - from pl_bolts.models.self_supervised import Moco_v2 - model = Moco_v2() - trainer = Trainer() - trainer.fit(model) - - CLI command:: - - # cifar10 - python moco2_module.py --gpus 1 - - # imagenet - python moco2_module.py - --gpus 8 - --dataset imagenet2012 - --data_dir /path/to/imagenet/ - --meta_dir /path/to/folder/with/meta.bin/ - --batch_size 32 - """ - - def __init__( - self, - base_encoder: Union[str, torch.nn.Module] = "resnet18", - emb_dim: int = 128, - num_negatives: int = 65536, - encoder_momentum: float = 0.999, - softmax_temperature: float = 0.07, - learning_rate: float = 0.03, - momentum: float = 0.9, - weight_decay: float = 1e-4, - data_dir: str = "./", - batch_size: int = 256, - use_mlp: bool = False, - num_workers: int = 8, - *args, - **kwargs - ) -> None: - """ - Args: - base_encoder: torchvision model name or torch.nn.Module - emb_dim: feature dimension (default: 128) - num_negatives: queue size; number of negative keys (default: 65536) - encoder_momentum: moco momentum of updating key encoder (default: 0.999) - softmax_temperature: softmax temperature (default: 0.07) - learning_rate: the learning rate - momentum: optimizer momentum - weight_decay: optimizer weight decay - datamodule: the DataModule (train, val, test dataloaders) - data_dir: the directory to store data - batch_size: batch size - use_mlp: add an mlp to the encoders - num_workers: workers for the loaders - """ - - super().__init__() - self.save_hyperparameters() - - # create the encoders - # num_classes is the output fc dimension - self.encoder_q, self.encoder_k = self.init_encoders(base_encoder) - - if use_mlp: # hack: brute-force replacement - dim_mlp = self.encoder_q.fc.weight.shape[1] - self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) - self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) - - for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): - param_k.data.copy_(param_q.data) # initialize - param_k.requires_grad = False # not update by gradient - - # create the queue - self.register_buffer("queue", torch.randn(emb_dim, num_negatives)) - self.queue = nn.functional.normalize(self.queue, dim=0) - - self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) - - # create the validation queue - self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives)) - self.val_queue = nn.functional.normalize(self.val_queue, dim=0) - - self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long)) - - def init_encoders(self, base_encoder): - """Override to add your own encoders.""" - - template_model = getattr(torchvision.models, base_encoder) - encoder_q = template_model(num_classes=self.hparams.emb_dim) - encoder_k = template_model(num_classes=self.hparams.emb_dim) - - return encoder_q, encoder_k - - @torch.no_grad() - def _momentum_update_key_encoder(self): - """Momentum update of the key encoder.""" - for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): - em = self.hparams.encoder_momentum - param_k.data = param_k.data * em + param_q.data * (1.0 - em) - - @torch.no_grad() - def _dequeue_and_enqueue(self, keys, queue_ptr, queue): - # gather keys before updating queue - if self._use_ddp(self.trainer): - keys = concat_all_gather(keys) - - batch_size = keys.shape[0] - - ptr = int(queue_ptr) - assert self.hparams.num_negatives % 
batch_size == 0 # for simplicity - - # replace the keys at ptr (dequeue and enqueue) - queue[:, ptr : ptr + batch_size] = keys.T - ptr = (ptr + batch_size) % self.hparams.num_negatives # move pointer - - queue_ptr[0] = ptr - - @torch.no_grad() - def _batch_shuffle_ddp(self, x): # pragma: no cover - """Batch shuffle, for making use of BatchNorm. - - *** Only support DistributedDataParallel (DDP) model. *** - """ - # gather from all gpus - batch_size_this = x.shape[0] - x_gather = concat_all_gather(x) - batch_size_all = x_gather.shape[0] - - num_gpus = batch_size_all // batch_size_this - - # random shuffle index - idx_shuffle = torch.randperm(batch_size_all).cuda() - - # broadcast to all gpus - torch.distributed.broadcast(idx_shuffle, src=0) - - # index for restoring - idx_unshuffle = torch.argsort(idx_shuffle) - - # shuffled index for this gpu - gpu_idx = torch.distributed.get_rank() - idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] - - return x_gather[idx_this], idx_unshuffle - - @torch.no_grad() - def _batch_unshuffle_ddp(self, x, idx_unshuffle): # pragma: no cover - """Undo batch shuffle. - - *** Only support DistributedDataParallel (DDP) model. *** - """ - # gather from all gpus - batch_size_this = x.shape[0] - x_gather = concat_all_gather(x) - batch_size_all = x_gather.shape[0] - - num_gpus = batch_size_all // batch_size_this - - # restored index for this gpu - gpu_idx = torch.distributed.get_rank() - idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] - - return x_gather[idx_this] - - def forward(self, img_q, img_k, queue): - """ - Input: - im_q: a batch of query images - im_k: a batch of key images - queue: a queue from which to pick negative samples - Output: - logits, targets - """ - - # compute query features - q = self.encoder_q(img_q) # queries: NxC - q = nn.functional.normalize(q, dim=1) - - # compute key features - with torch.no_grad(): # no gradient to keys - # shuffle for making use of BN - if self._use_ddp(self.trainer): - img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k) - - k = self.encoder_k(img_k) # keys: NxC - k = nn.functional.normalize(k, dim=1) - - # undo shuffle - if self._use_ddp(self.trainer): - k = self._batch_unshuffle_ddp(k, idx_unshuffle) - - # compute logits - # Einstein sum is more intuitive - # positive logits: Nx1 - l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) - # negative logits: NxK - l_neg = torch.einsum("nc,ck->nk", [q, queue.clone().detach()]) - - # logits: Nx(1+K) - logits = torch.cat([l_pos, l_neg], dim=1) - - # apply temperature - logits /= self.hparams.softmax_temperature - - # labels: positive key indicators - labels = torch.zeros(logits.shape[0], dtype=torch.long) - labels = labels.type_as(logits) - - return logits, labels, k - - def training_step(self, batch, batch_idx): - # in STL10 we pass in both lab+unl for online ft - if self.trainer.datamodule.name == "stl10": - # labeled_batch = batch[1] - unlabeled_batch = batch[0] - batch = unlabeled_batch - - (img_1, img_2), _ = batch - - self._momentum_update_key_encoder() # update the key encoder - output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.queue) - self._dequeue_and_enqueue(keys, queue=self.queue, queue_ptr=self.queue_ptr) # dequeue and enqueue - - loss = F.cross_entropy(output.float(), target.long()) - - acc1, acc5 = precision_at_k(output, target, top_k=(1, 5)) - - log = {"train_loss": loss, "train_acc1": acc1, "train_acc5": acc5} - self.log_dict(log) - return loss - - def validation_step(self, batch, batch_idx): - # in STL10 we pass in both lab+unl for 
online ft - if self.trainer.datamodule.name == "stl10": - # labeled_batch = batch[1] - unlabeled_batch = batch[0] - batch = unlabeled_batch - - (img_1, img_2), labels = batch - - output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.val_queue) - self._dequeue_and_enqueue(keys, queue=self.val_queue, queue_ptr=self.val_queue_ptr) # dequeue and enqueue - - loss = F.cross_entropy(output, target.long()) - - acc1, acc5 = precision_at_k(output, target, top_k=(1, 5)) - - return {"val_loss": loss, "val_acc1": acc1, "val_acc5": acc5} - - def validation_epoch_end(self, outputs): - val_loss = mean(outputs, "val_loss") - val_acc1 = mean(outputs, "val_acc1") - val_acc5 = mean(outputs, "val_acc5") - - log = {"val_loss": val_loss, "val_acc1": val_acc1, "val_acc5": val_acc5} - self.log_dict(log) - - def configure_optimizers(self): - optimizer = torch.optim.SGD( - self.parameters(), - self.hparams.learning_rate, - momentum=self.hparams.momentum, - weight_decay=self.hparams.weight_decay, - ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - self.trainer.max_epochs, - ) - return [optimizer], [scheduler] - - @staticmethod - def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument("--base_encoder", type=str, default="resnet18") - parser.add_argument("--emb_dim", type=int, default=128) - parser.add_argument("--num_workers", type=int, default=8) - parser.add_argument("--num_negatives", type=int, default=65536) - parser.add_argument("--encoder_momentum", type=float, default=0.999) - parser.add_argument("--softmax_temperature", type=float, default=0.07) - parser.add_argument("--learning_rate", type=float, default=0.03) - parser.add_argument("--momentum", type=float, default=0.9) - parser.add_argument("--weight_decay", type=float, default=1e-4) - parser.add_argument("--data_dir", type=str, default="./") - parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"]) - parser.add_argument("--batch_size", type=int, default=256) - parser.add_argument("--use_mlp", action="store_true") - parser.add_argument("--meta_dir", default=".", type=str, help="path to meta.bin for imagenet") - - return parser - - @staticmethod - def _use_ddp(trainer: Trainer) -> bool: - return isinstance(trainer.strategy, DDPStrategy) - - -# utils -@torch.no_grad() -@under_review() -def concat_all_gather(tensor): - """Performs all_gather operation on the provided tensors. - - *** Warning ***: torch.distributed.all_gather has no gradient. 
- """ - tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather(tensors_gather, tensor, async_op=False) - - return torch.cat(tensors_gather, dim=0) - - -@under_review() -def cli_main(): - from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule - - parser = ArgumentParser() - - # trainer args - parser = Trainer.add_argparse_args(parser) - - # model args - parser = Moco_v2.add_model_specific_args(parser) - args = parser.parse_args() - - if args.dataset == "cifar10": - datamodule = CIFAR10DataModule.from_argparse_args(args) - datamodule.train_transforms = MoCo2TrainCIFAR10Transforms() - datamodule.val_transforms = MoCo2EvalCIFAR10Transforms() - - elif args.dataset == "stl10": - datamodule = STL10DataModule.from_argparse_args(args) - datamodule.train_dataloader = datamodule.train_dataloader_mixed - datamodule.val_dataloader = datamodule.val_dataloader_mixed - datamodule.train_transforms = MoCo2TrainSTL10Transforms() - datamodule.val_transforms = MoCo2EvalSTL10Transforms() - - elif args.dataset == "imagenet2012": - datamodule = SSLImagenetDataModule.from_argparse_args(args) - datamodule.train_transforms = MoCo2TrainImagenetTransforms() - datamodule.val_transforms = MoCo2EvalImagenetTransforms() - - else: - # replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in - datamodule = None - - model = Moco_v2(**args.__dict__) - - trainer = Trainer.from_argparse_args(args) - trainer.fit(model, datamodule=datamodule) - - -if __name__ == "__main__": - cli_main() diff --git a/src/pl_bolts/models/self_supervised/moco/moco_module.py b/src/pl_bolts/models/self_supervised/moco/moco_module.py new file mode 100644 index 0000000000..cb9ddbdde1 --- /dev/null +++ b/src/pl_bolts/models/self_supervised/moco/moco_module.py @@ -0,0 +1,328 @@ +"""Adapted from: https://github.com/facebookresearch/moco. + +Original work is: Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved This implementation is: Copyright +(c) PyTorch Lightning, Inc. and its affiliates. All Rights Reserved + +This implementation is licensed under Attribution-NonCommercial 4.0 International; You may not use this file except in +compliance with the License. + +You may obtain a copy of the License from the LICENSE file present in this folder. +""" +from copy import copy, deepcopy +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +from pytorch_lightning import LightningModule +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import Tensor, nn, optim +from torch.nn import functional as F # noqa: N812 +from torch.utils.data import DataLoader, Dataset + +# It seems to be impossible to avoid mypy errors if using import instead of getattr(). 
+# See https://github.com/python/mypy/issues/8823
+try:
+    LRScheduler: Any = getattr(optim.lr_scheduler, "LRScheduler")
+except AttributeError:
+    LRScheduler = getattr(optim.lr_scheduler, "_LRScheduler")
+
+from pl_bolts.datamodules import CIFAR10DataModule
+from pl_bolts.metrics import precision_at_k
+from pl_bolts.models.self_supervised.moco.utils import concatenate_all, shuffle_batch, sort_batch, validate_batch
+from pl_bolts.transforms.self_supervised.moco_transforms import (
+    MoCo2EvalCIFAR10Transforms,
+    MoCo2TrainCIFAR10Transforms,
+)
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+from pl_bolts.utils.warnings import warn_missing_pkg
+
+if _TORCHVISION_AVAILABLE:
+    import torchvision
+else:  # pragma: no cover
+    warn_missing_pkg("torchvision")
+
+
+class RepresentationQueue(nn.Module):
+    """The queue is implemented as a list of representations and a pointer to the location where the next batch of
+    representations will be written."""
+
+    def __init__(self, representation_size: int, queue_size: int):
+        super().__init__()
+
+        self.representations: Tensor
+        self.register_buffer("representations", torch.randn(representation_size, queue_size))
+        self.representations = nn.functional.normalize(self.representations, dim=0)
+
+        self.pointer: Tensor
+        self.register_buffer("pointer", torch.zeros([], dtype=torch.long))
+
+    @torch.no_grad()
+    def dequeue_and_enqueue(self, x: Tensor) -> None:
+        """Replaces representations in the queue, starting at the current queue pointer, and advances the pointer.
+
+        Args:
+            x: A mini-batch of representations. The queue size has to be a multiple of the total number of
+                representations across all devices.
+        """
+        # Gather representations from all GPUs into a [batch_size * world_size, num_features] tensor, in case of
+        # distributed training.
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            x = concatenate_all(x)
+
+        queue_size = self.representations.shape[1]
+        batch_size = x.shape[0]
+        if queue_size % batch_size != 0:
+            raise ValueError(f"Queue size ({queue_size}) is not a multiple of the batch size ({batch_size}).")
+
+        end = self.pointer + batch_size
+        self.representations[:, int(self.pointer) : int(end)] = x.T
+        self.pointer = end % queue_size
+
+
+class MoCo(LightningModule):
+    def __init__(
+        self,
+        encoder: Union[str, nn.Module] = "resnet18",
+        head: Optional[nn.Module] = None,
+        representation_size: int = 128,
+        num_negatives: int = 65536,
+        encoder_momentum: float = 0.999,
+        temperature: float = 0.07,
+        exclude_bn_bias: bool = False,
+        optimizer: Type[optim.Optimizer] = optim.SGD,
+        optimizer_params: Optional[Dict[str, Any]] = None,
+        lr_scheduler: Type[LRScheduler] = optim.lr_scheduler.CosineAnnealingLR,
+        lr_scheduler_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """A module that trains an encoder using Momentum Contrast.
+
+        *MoCo paper*: `Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick
+        <https://arxiv.org/abs/1911.05722>`_
+
+        *MoCo v2 paper*: `Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He <https://arxiv.org/abs/2003.04297>`_
+
+        *Adapted from `facebookresearch/moco <https://github.com/facebookresearch/moco>`_ to Lightning by*:
+        `William Falcon <https://github.com/williamFalcon>`_
+
+        *Refactored by*: `Seppo Enarvi <https://github.com/senarvi>`_
+
+        Example::
+
+            from pl_bolts.models.self_supervised import MoCo
+            model = MoCo()
+            trainer = Trainer()
+            trainer.fit(model)
+
+        CLI command::
+
+            python moco_module.py fit \
+                --data.data_dir /path/to/imagenet \
+                --data.batch_size 32 \
+                --data.num_workers 4 \
+                --trainer.accelerator gpu \
+                --trainer.devices 8
+
+        Args:
+            encoder: The encoder module. Either a Torchvision model name or a ``torch.nn.Module``.
+            head: An optional projection head that will be appended to the encoder during training.
+            representation_size: Size of the feature vectors produced by the projection head (or, if no projection
+                head is used, by the encoder).
+            num_negatives: Number of negative examples to be kept in the queue.
+            encoder_momentum: Momentum for updating the key encoder.
+            temperature: The temperature parameter for the MoCo loss.
+            exclude_bn_bias: If ``True``, weight decay will not be applied to bias and batch normalization parameters.
+            optimizer: Which optimizer class to use for training.
+            optimizer_params: Parameters to pass to the optimizer constructor.
+            lr_scheduler: Which learning rate scheduler class to use for training.
+            lr_scheduler_params: Parameters to pass to the learning rate scheduler constructor.
+        """
+        super().__init__()
+
+        self.num_negatives = num_negatives
+        self.encoder_momentum = encoder_momentum
+        self.temperature = temperature
+        self.exclude_bn_bias = exclude_bn_bias
+        self.optimizer_class = optimizer
+        if optimizer_params is not None:
+            self.optimizer_params = optimizer_params
+        else:
+            self.optimizer_params = {"lr": 0.03, "momentum": 0.9, "weight_decay": 1e-4}
+        self.lr_scheduler_class = lr_scheduler
+        if lr_scheduler_params is not None:
+            self.lr_scheduler_params = lr_scheduler_params
+        else:
+            self.lr_scheduler_params = {"T_max": 100}
+
+        if isinstance(encoder, str):
+            template_model = getattr(torchvision.models, encoder)
+            self.encoder_q = template_model(num_classes=representation_size)
+        else:
+            self.encoder_q = encoder
+        self.encoder_k = deepcopy(self.encoder_q)
+        for param in self.encoder_k.parameters():
+            param.requires_grad = False
+
+        if head is not None:
+            self.head_q: Optional[nn.Module] = head
+            self.head_k: Optional[nn.Module] = deepcopy(head)
+            for param in self.head_k.parameters():
+                param.requires_grad = False
+        else:
+            self.head_q = None
+            self.head_k = None
+
+        # Two different queues of representations are needed, one for training and one for validation data.
+        self.queue = RepresentationQueue(representation_size, num_negatives)
+        self.val_queue = RepresentationQueue(representation_size, num_negatives)
+
+    def forward(self, query_images: Tensor, key_images: Tensor) -> Tuple[Tensor, Tensor]:
+        """Computes the forward passes of both encoders and projection heads.
+
+        Args:
+            query_images: A mini-batch of query images in a ``[batch_size, num_channels, height, width]`` tensor.
+            key_images: A mini-batch of key images in a ``[batch_size, num_channels, height, width]`` tensor.
+
+        Returns:
+            A tuple of query and key representations.
+        """
+        q = self.encoder_q(query_images)
+        if self.head_q is not None:
+            q = self.head_q(q)
+        q = nn.functional.normalize(q, dim=1)
+
+        with torch.no_grad():
+            # The keys are shuffled between the GPUs before encoding them, to avoid batch normalization leaking
+            # information between the samples. This works only when using the DDP strategy.
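+            # Illustration (assuming two GPUs and a local batch size of two): rank 0 holds keys [k0, k1] and rank 1
+            # holds [k2, k3]. shuffle_batch() gathers [k0, k1, k2, k3], permutes it with a random order broadcast
+            # from rank 0, and gives each rank a random subset to encode. sort_batch() below gathers the encoded
+            # keys again and returns, on each rank, the representations of its own original samples.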
+            if isinstance(self.trainer.strategy, DDPStrategy):
+                key_images, original_order = shuffle_batch(key_images)
+
+            k = self.encoder_k(key_images)
+            if self.head_k is not None:
+                k = self.head_k(k)
+            k = nn.functional.normalize(k, dim=1)
+
+            if isinstance(self.trainer.strategy, DDPStrategy):
+                k = sort_batch(k, original_order)
+
+        return q, k
+
+    def training_step(self, batch: Tuple[List[List[Tensor]], List[Any]], batch_idx: int) -> STEP_OUTPUT:
+        images = validate_batch(batch)
+        self._momentum_update_key_encoder()
+        loss, acc1, acc5 = self._calculate_loss(images, self.queue)
+        self.log("train/loss", loss, sync_dist=True)
+        self.log("train/acc1", acc1, sync_dist=True)
+        self.log("train/acc5", acc5, sync_dist=True)
+        return {"loss": loss}
+
+    def validation_step(self, batch: Tuple[List[List[Tensor]], List[Any]], batch_idx: int) -> Optional[STEP_OUTPUT]:
+        images = validate_batch(batch)
+        loss, acc1, acc5 = self._calculate_loss(images, self.val_queue)
+        self.log("val/loss", loss, sync_dist=True)
+        self.log("val/acc1", acc1, sync_dist=True)
+        self.log("val/acc5", acc5, sync_dist=True)
+
+    def configure_optimizers(self) -> Tuple[List[optim.Optimizer], List[optim.lr_scheduler._LRScheduler]]:
+        """Constructs the optimizer and learning rate scheduler based on ``self.optimizer_params`` and
+        ``self.lr_scheduler_params``.
+
+        If ``exclude_bn_bias`` is set and a weight decay is specified, weight decay will not be applied to bias and
+        batch normalization parameters.
+        """
+        if (
+            ("weight_decay" in self.optimizer_params)
+            and (self.optimizer_params["weight_decay"] != 0)
+            and self.exclude_bn_bias
+        ):
+            defaults = copy(self.optimizer_params)
+            weight_decay = defaults.pop("weight_decay")
+
+            wd_group = []
+            nowd_group = []
+            for name, tensor in self.named_parameters():
+                if not tensor.requires_grad:
+                    continue
+                if ("bias" in name) or ("bn" in name):
+                    nowd_group.append(tensor)
+                else:
+                    wd_group.append(tensor)
+
+            params = [
+                {"params": wd_group, "weight_decay": weight_decay},
+                {"params": nowd_group, "weight_decay": 0.0},
+            ]
+            optimizer = self.optimizer_class(params, **defaults)
+        else:
+            optimizer = self.optimizer_class(self.parameters(), **self.optimizer_params)
+        lr_scheduler = self.lr_scheduler_class(optimizer, **self.lr_scheduler_params)
+        return [optimizer], [lr_scheduler]
+
+    @torch.no_grad()
+    def _momentum_update_key_encoder(self) -> None:
+        """Momentum update of the key encoder."""
+        momentum = self.encoder_momentum
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data = param_k.data * momentum + param_q.data * (1.0 - momentum)
+
+    def _calculate_loss(self, images: Tensor, queue: RepresentationQueue) -> Tuple[Tensor, Tensor, Tensor]:
+        """Calculates the normalized temperature-scaled cross entropy loss from a mini-batch of image pairs.
+
+        Args:
+            images: A mini-batch of image pairs in a ``[batch_size, 2, num_channels, height, width]`` tensor.
+            queue: The queue that the query representations will be compared against. The key representations will be
+                added to the queue.
+        """
+        if images.size(1) != 2:
+            raise ValueError(
+                f"MoCo expects two transformations of every image. Got {images.size(1)} transformations of an image."
+            )
+
+        query_images = images[:, 0]
+        key_images = images[:, 1]
+        q, k = self(query_images, key_images)
+
+        # Concatenate logits from the positive pairs (batch_size x 1) and the negative pairs (batch_size x queue_size).
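+        # pos_logits[i] is the dot product of q[i] and k[i]; neg_logits[i, j] is the dot product of q[i] and the
+        # j:th representation in the queue. With q and k of shape [batch_size, representation_size] and the queue
+        # of shape [representation_size, queue_size], the result is a [batch_size, 1 + queue_size] logit matrix.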
+        pos_logits = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
+        neg_logits = torch.einsum("nc,ck->nk", [q, queue.representations.clone().detach()])
+        logits = torch.cat([pos_logits, neg_logits], dim=1)
+        logits /= self.temperature
+
+        # The correct label for every query is 0. Calculate the cross entropy of classifying each query correctly.
+        target_idxs = torch.zeros(logits.shape[0], dtype=torch.long).type_as(logits)
+        loss = F.cross_entropy(logits, target_idxs.long())
+        acc1, acc5 = precision_at_k(logits, target_idxs, top_k=(1, 5))
+
+        queue.dequeue_and_enqueue(k)
+        return loss, acc1, acc5
+
+
+def collate(samples: List[Tuple[Tuple[Tensor, Tensor], int]]) -> Tuple[List[Tuple[Tensor, Tensor]], List[int]]:
+    return tuple(zip(*samples))  # type: ignore
+
+
+class CIFAR10ContrastiveDataModule(CIFAR10DataModule):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(
+            *args,
+            train_transforms=MoCo2TrainCIFAR10Transforms(),
+            val_transforms=MoCo2EvalCIFAR10Transforms(),
+            **kwargs,
+        )
+
+    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
+        return DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            drop_last=self.drop_last,
+            pin_memory=self.pin_memory,
+            collate_fn=collate,
+        )
+
+
+def cli_main() -> None:
+    from pytorch_lightning.cli import LightningCLI
+
+    LightningCLI(MoCo, CIFAR10ContrastiveDataModule, seed_everything_default=42)
+
+
+if __name__ == "__main__":
+    cli_main()
diff --git a/src/pl_bolts/models/self_supervised/moco/utils.py b/src/pl_bolts/models/self_supervised/moco/utils.py
new file mode 100644
index 0000000000..ef12e633e1
--- /dev/null
+++ b/src/pl_bolts/models/self_supervised/moco/utils.py
@@ -0,0 +1,163 @@
+from typing import Any, List, Tuple
+
+import torch
+from torch import Tensor
+
+
+def validate_batch(batch: Tuple[List[List[Tensor]], List[Any]]) -> Tensor:
+    """Reads a batch of data, validates the format, and stacks the images into a single tensor.
+
+    Contrastive SSL models expect each image in the batch to be transformed in multiple (typically two) ways. The
+    input is similar to image classification and object detection models, but a tuple of images is expected in place
+    of each image.
+
+    Args:
+        batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. A tuple containing a nested list
+            of ``N`` image pairs (or tuples of more than two images) and a list of ``N`` targets.
+
+    Returns:
+        The input batch with images stacked into a single ``[N, 2, channels, height, width]`` tensor.
+    """
+    images, targets = batch
+
+    if not images:
+        raise ValueError("No images in batch.")
+
+    batch_size = len(images)
+    if batch_size != len(targets):
+        raise ValueError(f"Got {batch_size} image pairs, but {len(targets)} targets.")
+
+    image_transforms = images[0]
+
+    if isinstance(image_transforms, Tensor):
+        if image_transforms.ndim != 4:
+            raise ValueError(
+                "Contrastive training expects the transformed images as a tuple, a list, or a 4-dimensional tensor. "
+                f"Got a tensor with {image_transforms.ndim} dimensions."
+ ) + shape = image_transforms.shape + for image_transforms in images[1:]: + if not isinstance(image_transforms, Tensor): + raise ValueError(f"Expected transformed images in a tensor, got {type(image_transforms)}.") + if image_transforms.shape != shape: + raise ValueError( + f"Different shapes for transformed images in one batch: {shape} and {image_transforms.shape}" + ) + + return torch.stack(images) + + if isinstance(image_transforms, (tuple, list)): + num_transforms = len(image_transforms) + if num_transforms < 2: + raise ValueError( + f"Contrastive training expects at least two transformations of every image, got {num_transforms}." + ) + if not isinstance(image_transforms[0], Tensor): + raise ValueError(f"Expected image to be of type Tensor, got {type(image_transforms[0]).__name__}.") + shape = image_transforms[0].shape + for image_transforms in images: + for image in image_transforms: + if not isinstance(image, Tensor): + raise ValueError(f"Expected image to be of type Tensor, got {type(image).__name__}.") + if image.shape != shape: + raise ValueError(f"Images with different shapes in one batch: {shape} and {image.shape}") + + # PyTorch doesn't stack nested lists of tensors. Stacking the tensors in two steps would cause the data to be + # copied twice, so instead we'll first flatten the hierarchy and then reshape in the end. + flat_images = [image for image_transforms in images for image in image_transforms] + flat_images = torch.stack(flat_images) # [batch_size * num_transforms, channels, height, width] + return flat_images.view(batch_size, num_transforms, *shape) + + raise ValueError( + "Contrastive training expects the transformed images as a tuple, a list, or a 4-dimensional tensor. Got " + f"{type(image_transforms).__name__}." + ) + + +class ConcatenateAll(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, tensor: Tensor) -> Tensor: # type: ignore + """Concatenates tensors from all GPUs.""" + ctx.batch_size = tensor.shape[0] + + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered_tensor, tensor.contiguous()) + return torch.cat(gathered_tensor, 0) + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tensor: # type: ignore + """Sums the gradients from all GPUs and takes the ones corresponding to our mini-batch.""" + start_idx = torch.distributed.get_rank() * ctx.batch_size + stop_idx = start_idx + ctx.batch_size + + grad_input = grad_output.clone().contiguous() + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM) + return grad_input[start_idx:stop_idx] + + +@torch.no_grad() +def concatenate_all(tensor: Tensor) -> Tensor: + """Performs ``all_gather`` operation to concatenate the provided tensor from all devices. + + This function has no gradient. + """ + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered_tensor, tensor.contiguous()) + return torch.cat(gathered_tensor, 0) + + +@torch.no_grad() +def shuffle_batch(x: Tensor) -> Tuple[Tensor, Tensor]: + """Redistributes the batch randomly to different devices. + + Gathers a mini-batch from all devices and shuffles it into a random order. Each device will receive a random subset + of the mini-batch. Only support Distributed Data Parallel (DDP) training strategy. + + Args: + x: The input tensor, whose first dimension is the batch. 
+
+    Returns:
+        The output tensor and a list of indices that gives the original order of the combined mini-batch. The output
+        tensor is the same size as the input tensor, but contains a random subset of the combined mini-batch.
+    """
+    all_x = concatenate_all(x)
+
+    local_batch_size = x.shape[0]
+    global_batch_size = all_x.shape[0]
+    num_gpus = global_batch_size // local_batch_size
+
+    # Create a random ordering of the images in all GPUs and broadcast it from rank 0 to the other GPUs.
+    random_order = torch.randperm(global_batch_size).cuda()
+    torch.distributed.broadcast(random_order, src=0)
+
+    # Save a mapping from the shuffled order back to the linear order.
+    original_order = torch.argsort(random_order)
+
+    rank = torch.distributed.get_rank()
+    local_idxs = random_order.view(num_gpus, -1)[rank]
+    return all_x[local_idxs], original_order
+
+
+@torch.no_grad()
+def sort_batch(x: Tensor, order: Tensor) -> Tensor:
+    """Sorts the samples across devices into the given order.
+
+    Gathers a mini-batch from all devices and sorts it into the given order. Each device will receive a consecutive
+    subset of the mini-batch. Only supports the Distributed Data Parallel (DDP) training strategy.
+
+    Args:
+        x: The input tensor, whose first dimension is the batch.
+        order: Indices to the combined mini-batch in the correct order.
+
+    Returns:
+        The subset of the combined mini-batch that corresponds to this device.
+    """
+    all_x = concatenate_all(x)
+
+    local_batch_size = x.shape[0]
+    global_batch_size = all_x.shape[0]
+    num_gpus = global_batch_size // local_batch_size
+
+    rank = torch.distributed.get_rank()
+    local_idxs = order.view(num_gpus, -1)[rank]
+    return all_x[local_idxs]
diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py
index 0605b9c844..d5c0108486 100644
--- a/tests/models/self_supervised/test_models.py
+++ b/tests/models/self_supervised/test_models.py
@@ -3,9 +3,9 @@
 import pytest
 import torch
 from pl_bolts.datamodules import CIFAR10DataModule
-from pl_bolts.models.self_supervised import AMDIM, BYOL, CPC_v2, Moco_v2, SimCLR, SimSiam, SwAV
+from pl_bolts.models.self_supervised import AMDIM, BYOL, CPC_v2, MoCo, SimCLR, SimSiam, SwAV
 from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10
-from pl_bolts.models.self_supervised.moco.callbacks import MocoLRScheduler
+from pl_bolts.models.self_supervised.moco.callbacks import MoCoLRScheduler
 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
 from pl_bolts.transforms.self_supervised.moco_transforms import MoCo2EvalCIFAR10Transforms, MoCo2TrainCIFAR10Transforms
 from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
@@ -79,8 +79,8 @@ def test_moco(tmpdir, datadir):
     datamodule.train_transforms = MoCo2TrainCIFAR10Transforms()
     datamodule.val_transforms = MoCo2EvalCIFAR10Transforms()
 
-    model = Moco_v2(data_dir=datadir, batch_size=2, online_ft=True)
-    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, callbacks=[MocoLRScheduler()])
+    model = MoCo()
+    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, callbacks=[MoCoLRScheduler()])
     trainer.fit(model, datamodule=datamodule)
 
 
diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py
index 9cdad7284d..e4124cfd65 100644
--- a/tests/models/test_scripts.py
+++ b/tests/models/test_scripts.py
@@ -86,8 +86,8 @@
         ],
     ),
     pytest.param(
-        "models.self_supervised.moco.moco2_module",
-        _DEFAULT_ARGS + _ARG_WORKERS_0 + _ARG_GPUS,
+        "models.self_supervised.moco.moco_module",
+        _DEFAULT_LIGHTNING_CLI_ARGS + f" --data.num_workers=0 --trainer.gpus={int(torch.cuda.is_available())}",
         marks=pytest.mark.skipif(**_MARK_REQUIRE_GPU),
     ),
     pytest.param(