diff --git a/recognition/arcface_torch/README.md b/recognition/arcface_torch/README.md index 1ed55b079..6899b4671 100644 --- a/recognition/arcface_torch/README.md +++ b/recognition/arcface_torch/README.md @@ -1,7 +1,9 @@ # Distributed Arcface Training in Pytorch This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions -identity on a single server. +identity on a single server. + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-c)](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient) ## Requirements @@ -38,8 +40,12 @@ Node 1: python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus ``` -config.num_classes = 85742 -config.num_image = 5822653 +### 3. Run ViT-B on a machine with 24k batchsize: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12345 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b.py +``` + ## Download Datasets or Prepare Datasets - [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images) @@ -83,6 +89,7 @@ globalised multi-racial testset contains 242,143 identities and 1,624,305 images | WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) | | WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) | | WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) | +| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) | #### 2. Training on Multi-Host GPU diff --git a/recognition/arcface_torch/configs/base.py b/recognition/arcface_torch/configs/base.py index 17a369bb9..3c2e307c9 100644 --- a/recognition/arcface_torch/configs/base.py +++ b/recognition/arcface_torch/configs/base.py @@ -39,6 +39,8 @@ # For Large Sacle Dataset, such as WebFace42M config.dali = False +# Gradient ACC +config.gradient_acc = 1 # setup seed config.seed = 2048 diff --git a/recognition/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py b/recognition/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py new file mode 100644 index 000000000..37105d455 --- /dev/null +++ b/recognition/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "vit_b_dp005_mask_005" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.weight_decay = 0.1 +config.batch_size = 256 +config.gradient_acc = 12 # total batchsize is 256 * 12 +config.optimizer = "adamw" +config.lr = 0.001 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 40 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/recognition/arcface_torch/dataset.py b/recognition/arcface_torch/dataset.py index c03725e4f..f1b51797f 100644 --- a/recognition/arcface_torch/dataset.py +++ b/recognition/arcface_torch/dataset.py @@ -32,6 +32,7 @@ def get_dataloader( # Synthetic if root_dir == "synthetic": train_set = SyntheticDataset() + dali = False # Mxnet RecordIO elif os.path.exists(rec) and os.path.exists(idx): diff --git a/recognition/arcface_torch/partial_fc_v2.py b/recognition/arcface_torch/partial_fc_v2.py new file mode 100644 index 000000000..0752554ca --- /dev/null +++ b/recognition/arcface_torch/partial_fc_v2.py @@ -0,0 +1,260 @@ + +import math +from typing import Callable + +import torch +from torch import distributed +from torch.nn.functional import linear, normalize + + +class PartialFC_V2(torch.nn.Module): + """ + https://arxiv.org/abs/2203.15565 + A distributed sparsely updating variant of the FC layer, named Partial FC (PFC). + When sample rate less than 1, in each iteration, positive class centers and a random subset of + negative class centers are selected to compute the margin-based softmax loss, all class + centers are still maintained throughout the whole training process, but only a subset is + selected and updated in each iteration. + .. note:: + When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1). + Example: + -------- + >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2) + >>> for img, labels in data_loader: + >>> embeddings = net(img) + >>> loss = module_pfc(embeddings, labels) + >>> loss.backward() + >>> optimizer.step() + """ + _version = 2 + + def __init__( + self, + margin_loss: Callable, + embedding_size: int, + num_classes: int, + sample_rate: float = 1.0, + fp16: bool = False, + ): + """ + Paramenters: + ----------- + embedding_size: int + The dimension of embedding, required + num_classes: int + Total number of classes, required + sample_rate: float + The rate of negative centers participating in the calculation, default is 1.0. + """ + super(PartialFC_V2, self).__init__() + assert ( + distributed.is_initialized() + ), "must initialize distributed before create this" + self.rank = distributed.get_rank() + self.world_size = distributed.get_world_size() + + self.dist_cross_entropy = DistCrossEntropy() + self.embedding_size = embedding_size + self.sample_rate: float = sample_rate + self.fp16 = fp16 + self.num_local: int = num_classes // self.world_size + int( + self.rank < num_classes % self.world_size + ) + self.class_start: int = num_classes // self.world_size * self.rank + min( + self.rank, num_classes % self.world_size + ) + self.num_sample: int = int(self.sample_rate * self.num_local) + self.last_batch_size: int = 0 + + self.is_updated: bool = True + self.init_weight_update: bool = True + self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size))) + + # margin_loss + if isinstance(margin_loss, Callable): + self.margin_softmax = margin_loss + else: + raise + + def sample(self, labels, index_positive): + """ + This functions will change the value of labels + Parameters: + ----------- + labels: torch.Tensor + pass + index_positive: torch.Tensor + pass + optimizer: torch.optim.Optimizer + pass + """ + with torch.no_grad(): + positive = torch.unique(labels[index_positive], sorted=True).cuda() + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local]).cuda() + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1].cuda() + index = index.sort()[0].cuda() + else: + index = positive + self.weight_index = index + + labels[index_positive] = torch.searchsorted(index, labels[index_positive]) + + return self.weight[self.weight_index] + + def forward( + self, + local_embeddings: torch.Tensor, + local_labels: torch.Tensor, + ): + """ + Parameters: + ---------- + local_embeddings: torch.Tensor + feature embeddings on each GPU(Rank). + local_labels: torch.Tensor + labels on each GPU(Rank). + Returns: + ------- + loss: torch.Tensor + pass + """ + local_labels.squeeze_() + local_labels = local_labels.long() + + batch_size = local_embeddings.size(0) + if self.last_batch_size == 0: + self.last_batch_size = batch_size + assert self.last_batch_size == batch_size, ( + f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}") + + _gather_embeddings = [ + torch.zeros((batch_size, self.embedding_size)).cuda() + for _ in range(self.world_size) + ] + _gather_labels = [ + torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) + ] + _list_embeddings = AllGather(local_embeddings, *_gather_embeddings) + distributed.all_gather(_gather_labels, local_labels) + + embeddings = torch.cat(_list_embeddings) + labels = torch.cat(_gather_labels) + + labels = labels.view(-1, 1) + index_positive = (self.class_start <= labels) & ( + labels < self.class_start + self.num_local + ) + labels[~index_positive] = -1 + labels[index_positive] -= self.class_start + + if self.sample_rate < 1: + weight = self.sample(labels, index_positive) + else: + weight = self.weight + + with torch.cuda.amp.autocast(self.fp16): + norm_embeddings = normalize(embeddings) + norm_weight_activated = normalize(weight) + logits = linear(norm_embeddings, norm_weight_activated) + if self.fp16: + logits = logits.float() + logits = logits.clamp(-1, 1) + + logits = self.margin_softmax(logits, labels) + loss = self.dist_cross_entropy(logits, labels) + return loss + + +class DistCrossEntropyFunc(torch.autograd.Function): + """ + CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax. + Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): + """ + + @staticmethod + def forward(ctx, logits: torch.Tensor, label: torch.Tensor): + """ """ + batch_size = logits.size(0) + # for numerical stability + max_logits, _ = torch.max(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(max_logits, distributed.ReduceOp.MAX) + logits.sub_(max_logits) + logits.exp_() + sum_logits_exp = torch.sum(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM) + logits.div_(sum_logits_exp) + index = torch.where(label != -1)[0] + # loss + loss = torch.zeros(batch_size, 1, device=logits.device) + loss[index] = logits[index].gather(1, label[index]) + distributed.all_reduce(loss, distributed.ReduceOp.SUM) + ctx.save_for_backward(index, logits, label) + return loss.clamp_min_(1e-30).log_().mean() * (-1) + + @staticmethod + def backward(ctx, loss_gradient): + """ + Args: + loss_grad (torch.Tensor): gradient backward by last layer + Returns: + gradients for each input in forward function + `None` gradients for one-hot label + """ + ( + index, + logits, + label, + ) = ctx.saved_tensors + batch_size = logits.size(0) + one_hot = torch.zeros( + size=[index.size(0), logits.size(1)], device=logits.device + ) + one_hot.scatter_(1, label[index], 1) + logits[index] -= one_hot + logits.div_(batch_size) + return logits * loss_gradient.item(), None + + +class DistCrossEntropy(torch.nn.Module): + def __init__(self): + super(DistCrossEntropy, self).__init__() + + def forward(self, logit_part, label_part): + return DistCrossEntropyFunc.apply(logit_part, label_part) + + +class AllGatherFunc(torch.autograd.Function): + """AllGather op with gradient backward""" + + @staticmethod + def forward(ctx, tensor, *gather_list): + gather_list = list(gather_list) + distributed.all_gather(gather_list, tensor) + return tuple(gather_list) + + @staticmethod + def backward(ctx, *grads): + grad_list = list(grads) + rank = distributed.get_rank() + grad_out = grad_list[rank] + + dist_ops = [ + distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True) + if i == rank + else distributed.reduce( + grad_list[i], i, distributed.ReduceOp.SUM, async_op=True + ) + for i in range(distributed.get_world_size()) + ] + for _op in dist_ops: + _op.wait() + + grad_out *= len(grad_list) # cooperate with distributed loss function + return (grad_out, *[None for _ in range(len(grad_list))]) + + +AllGather = AllGatherFunc.apply diff --git a/recognition/arcface_torch/train_v2.py b/recognition/arcface_torch/train_v2.py new file mode 100755 index 000000000..c41695431 --- /dev/null +++ b/recognition/arcface_torch/train_v2.py @@ -0,0 +1,209 @@ +import argparse +import logging +import os + +import numpy as np +import torch +from torch import distributed +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from backbones import get_model +from dataset import get_dataloader +from losses import CombinedMarginLoss +from lr_scheduler import PolyScheduler +from partial_fc_v2 import PartialFC_V2 +from utils.utils_callbacks import CallBackLogging, CallBackVerification +from utils.utils_config import get_config +from utils.utils_logging import AverageMeter, init_logging +from utils.utils_distributed_sampler import setup_seed + +assert torch.__version__ >= "1.9.0", "In order to enjoy the features of the new torch, \ +we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future." + +try: + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + distributed.init_process_group("nccl") +except KeyError: + world_size = 1 + rank = 0 + distributed.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:12584", + rank=rank, + world_size=world_size, + ) + + +def main(args): + + # get config + cfg = get_config(args.config) + # global control random seed + setup_seed(seed=cfg.seed, cuda_deterministic=False) + + torch.cuda.set_device(args.local_rank) + + os.makedirs(cfg.output, exist_ok=True) + init_logging(rank, cfg.output) + + summary_writer = ( + SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) + if rank == 0 + else None + ) + + train_loader = get_dataloader( + cfg.rec, + args.local_rank, + cfg.batch_size, + cfg.dali, + cfg.seed, + cfg.num_workers + ) + + backbone = get_model( + cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda() + + backbone = torch.nn.parallel.DistributedDataParallel( + module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16, + find_unused_parameters=True) + + backbone.train() + # FIXME using gradient checkpoint if there are some unused parameters will cause error + backbone._set_static_graph() + + margin_loss = CombinedMarginLoss( + 64, + cfg.margin_list[0], + cfg.margin_list[1], + cfg.margin_list[2], + cfg.interclass_filtering_threshold + ) + + if cfg.optimizer == "sgd": + module_partial_fc = PartialFC_V2( + margin_loss, cfg.embedding_size, cfg.num_classes, + cfg.sample_rate, cfg.fp16) + module_partial_fc.train().cuda() + # TODO the params of partial fc must be last in the params list + opt = torch.optim.SGD( + params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], + lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay) + + elif cfg.optimizer == "adamw": + module_partial_fc = PartialFC_V2( + margin_loss, cfg.embedding_size, cfg.num_classes, + cfg.sample_rate, cfg.fp16) + module_partial_fc.train().cuda() + opt = torch.optim.AdamW( + params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], + lr=cfg.lr, weight_decay=cfg.weight_decay) + else: + raise + + cfg.total_batch_size = cfg.batch_size * world_size + cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch + cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch + + lr_scheduler = PolyScheduler( + optimizer=opt, + base_lr=cfg.lr, + max_steps=cfg.total_step, + warmup_steps=cfg.warmup_step, + last_epoch=-1 + ) + + start_epoch = 0 + global_step = 0 + if cfg.resume: + dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) + start_epoch = dict_checkpoint["epoch"] + global_step = dict_checkpoint["global_step"] + backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) + module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) + opt.load_state_dict(dict_checkpoint["state_optimizer"]) + lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) + del dict_checkpoint + + for key, value in cfg.items(): + num_space = 25 - len(key) + logging.info(": " + key + " " * num_space + str(value)) + + callback_verification = CallBackVerification( + val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer + ) + callback_logging = CallBackLogging( + frequent=cfg.frequent, + total_step=cfg.total_step, + batch_size=cfg.batch_size, + start_step = global_step, + writer=summary_writer + ) + + loss_am = AverageMeter() + amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) + + for epoch in range(start_epoch, cfg.num_epoch): + + if isinstance(train_loader, DataLoader): + train_loader.sampler.set_epoch(epoch) + for _, (img, local_labels) in enumerate(train_loader): + global_step += 1 + local_embeddings = backbone(img) + loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels) + + if cfg.fp16: + amp.scale(loss).backward() + if global_step % cfg.gradient_acc == 0: + amp.unscale_(opt) + torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) + amp.step(opt) + amp.update() + opt.zero_grad() + else: + loss.backward() + if global_step % cfg.gradient_acc == 0: + torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) + opt.step() + opt.zero_grad() + lr_scheduler.step() + + with torch.no_grad(): + loss_am.update(loss.item(), 1) + callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp) + + if global_step % cfg.verbose == 0 and global_step > 0: + callback_verification(global_step, backbone) + + if cfg.save_all_states: + checkpoint = { + "epoch": epoch + 1, + "global_step": global_step, + "state_dict_backbone": backbone.module.state_dict(), + "state_dict_softmax_fc": module_partial_fc.state_dict(), + "state_optimizer": opt.state_dict(), + "state_lr_scheduler": lr_scheduler.state_dict() + } + torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) + + if rank == 0: + path_module = os.path.join(cfg.output, "model.pt") + torch.save(backbone.module.state_dict(), path_module) + + if cfg.dali: + train_loader.reset() + + if rank == 0: + path_module = os.path.join(cfg.output, "model.pt") + torch.save(backbone.module.state_dict(), path_module) + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser( + description="Distributed Arcface Training in Pytorch") + parser.add_argument("config", type=str, help="py config file") + parser.add_argument("--local_rank", type=int, default=0, help="local_rank") + main(parser.parse_args())