-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e78eee5
commit 4f227ba
Showing
6 changed files
with
510 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
recognition/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from easydict import EasyDict as edict | ||
|
||
# make training faster | ||
# our RAM is 256G | ||
# mount -t tmpfs -o size=140G tmpfs /train_tmp | ||
|
||
config = edict() | ||
config.margin_list = (1.0, 0.0, 0.4) | ||
config.network = "vit_b_dp005_mask_005" | ||
config.resume = False | ||
config.output = None | ||
config.embedding_size = 512 | ||
config.sample_rate = 0.3 | ||
config.fp16 = True | ||
config.weight_decay = 0.1 | ||
config.batch_size = 256 | ||
config.gradient_acc = 12 # total batchsize is 256 * 12 | ||
config.optimizer = "adamw" | ||
config.lr = 0.001 | ||
config.verbose = 2000 | ||
config.dali = False | ||
|
||
config.rec = "/train_tmp/WebFace42M" | ||
config.num_classes = 2059906 | ||
config.num_image = 42474557 | ||
config.num_epoch = 40 | ||
config.warmup_epoch = config.num_epoch // 10 | ||
config.val_targets = [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,260 @@ | ||
|
||
import math | ||
from typing import Callable | ||
|
||
import torch | ||
from torch import distributed | ||
from torch.nn.functional import linear, normalize | ||
|
||
|
||
class PartialFC_V2(torch.nn.Module): | ||
""" | ||
https://arxiv.org/abs/2203.15565 | ||
A distributed sparsely updating variant of the FC layer, named Partial FC (PFC). | ||
When sample rate less than 1, in each iteration, positive class centers and a random subset of | ||
negative class centers are selected to compute the margin-based softmax loss, all class | ||
centers are still maintained throughout the whole training process, but only a subset is | ||
selected and updated in each iteration. | ||
.. note:: | ||
When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1). | ||
Example: | ||
-------- | ||
>>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2) | ||
>>> for img, labels in data_loader: | ||
>>> embeddings = net(img) | ||
>>> loss = module_pfc(embeddings, labels) | ||
>>> loss.backward() | ||
>>> optimizer.step() | ||
""" | ||
_version = 2 | ||
|
||
def __init__( | ||
self, | ||
margin_loss: Callable, | ||
embedding_size: int, | ||
num_classes: int, | ||
sample_rate: float = 1.0, | ||
fp16: bool = False, | ||
): | ||
""" | ||
Paramenters: | ||
----------- | ||
embedding_size: int | ||
The dimension of embedding, required | ||
num_classes: int | ||
Total number of classes, required | ||
sample_rate: float | ||
The rate of negative centers participating in the calculation, default is 1.0. | ||
""" | ||
super(PartialFC_V2, self).__init__() | ||
assert ( | ||
distributed.is_initialized() | ||
), "must initialize distributed before create this" | ||
self.rank = distributed.get_rank() | ||
self.world_size = distributed.get_world_size() | ||
|
||
self.dist_cross_entropy = DistCrossEntropy() | ||
self.embedding_size = embedding_size | ||
self.sample_rate: float = sample_rate | ||
self.fp16 = fp16 | ||
self.num_local: int = num_classes // self.world_size + int( | ||
self.rank < num_classes % self.world_size | ||
) | ||
self.class_start: int = num_classes // self.world_size * self.rank + min( | ||
self.rank, num_classes % self.world_size | ||
) | ||
self.num_sample: int = int(self.sample_rate * self.num_local) | ||
self.last_batch_size: int = 0 | ||
|
||
self.is_updated: bool = True | ||
self.init_weight_update: bool = True | ||
self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size))) | ||
|
||
# margin_loss | ||
if isinstance(margin_loss, Callable): | ||
self.margin_softmax = margin_loss | ||
else: | ||
raise | ||
|
||
def sample(self, labels, index_positive): | ||
""" | ||
This functions will change the value of labels | ||
Parameters: | ||
----------- | ||
labels: torch.Tensor | ||
pass | ||
index_positive: torch.Tensor | ||
pass | ||
optimizer: torch.optim.Optimizer | ||
pass | ||
""" | ||
with torch.no_grad(): | ||
positive = torch.unique(labels[index_positive], sorted=True).cuda() | ||
if self.num_sample - positive.size(0) >= 0: | ||
perm = torch.rand(size=[self.num_local]).cuda() | ||
perm[positive] = 2.0 | ||
index = torch.topk(perm, k=self.num_sample)[1].cuda() | ||
index = index.sort()[0].cuda() | ||
else: | ||
index = positive | ||
self.weight_index = index | ||
|
||
labels[index_positive] = torch.searchsorted(index, labels[index_positive]) | ||
|
||
return self.weight[self.weight_index] | ||
|
||
def forward( | ||
self, | ||
local_embeddings: torch.Tensor, | ||
local_labels: torch.Tensor, | ||
): | ||
""" | ||
Parameters: | ||
---------- | ||
local_embeddings: torch.Tensor | ||
feature embeddings on each GPU(Rank). | ||
local_labels: torch.Tensor | ||
labels on each GPU(Rank). | ||
Returns: | ||
------- | ||
loss: torch.Tensor | ||
pass | ||
""" | ||
local_labels.squeeze_() | ||
local_labels = local_labels.long() | ||
|
||
batch_size = local_embeddings.size(0) | ||
if self.last_batch_size == 0: | ||
self.last_batch_size = batch_size | ||
assert self.last_batch_size == batch_size, ( | ||
f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}") | ||
|
||
_gather_embeddings = [ | ||
torch.zeros((batch_size, self.embedding_size)).cuda() | ||
for _ in range(self.world_size) | ||
] | ||
_gather_labels = [ | ||
torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) | ||
] | ||
_list_embeddings = AllGather(local_embeddings, *_gather_embeddings) | ||
distributed.all_gather(_gather_labels, local_labels) | ||
|
||
embeddings = torch.cat(_list_embeddings) | ||
labels = torch.cat(_gather_labels) | ||
|
||
labels = labels.view(-1, 1) | ||
index_positive = (self.class_start <= labels) & ( | ||
labels < self.class_start + self.num_local | ||
) | ||
labels[~index_positive] = -1 | ||
labels[index_positive] -= self.class_start | ||
|
||
if self.sample_rate < 1: | ||
weight = self.sample(labels, index_positive) | ||
else: | ||
weight = self.weight | ||
|
||
with torch.cuda.amp.autocast(self.fp16): | ||
norm_embeddings = normalize(embeddings) | ||
norm_weight_activated = normalize(weight) | ||
logits = linear(norm_embeddings, norm_weight_activated) | ||
if self.fp16: | ||
logits = logits.float() | ||
logits = logits.clamp(-1, 1) | ||
|
||
logits = self.margin_softmax(logits, labels) | ||
loss = self.dist_cross_entropy(logits, labels) | ||
return loss | ||
|
||
|
||
class DistCrossEntropyFunc(torch.autograd.Function): | ||
""" | ||
CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax. | ||
Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): | ||
""" | ||
|
||
@staticmethod | ||
def forward(ctx, logits: torch.Tensor, label: torch.Tensor): | ||
""" """ | ||
batch_size = logits.size(0) | ||
# for numerical stability | ||
max_logits, _ = torch.max(logits, dim=1, keepdim=True) | ||
# local to global | ||
distributed.all_reduce(max_logits, distributed.ReduceOp.MAX) | ||
logits.sub_(max_logits) | ||
logits.exp_() | ||
sum_logits_exp = torch.sum(logits, dim=1, keepdim=True) | ||
# local to global | ||
distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM) | ||
logits.div_(sum_logits_exp) | ||
index = torch.where(label != -1)[0] | ||
# loss | ||
loss = torch.zeros(batch_size, 1, device=logits.device) | ||
loss[index] = logits[index].gather(1, label[index]) | ||
distributed.all_reduce(loss, distributed.ReduceOp.SUM) | ||
ctx.save_for_backward(index, logits, label) | ||
return loss.clamp_min_(1e-30).log_().mean() * (-1) | ||
|
||
@staticmethod | ||
def backward(ctx, loss_gradient): | ||
""" | ||
Args: | ||
loss_grad (torch.Tensor): gradient backward by last layer | ||
Returns: | ||
gradients for each input in forward function | ||
`None` gradients for one-hot label | ||
""" | ||
( | ||
index, | ||
logits, | ||
label, | ||
) = ctx.saved_tensors | ||
batch_size = logits.size(0) | ||
one_hot = torch.zeros( | ||
size=[index.size(0), logits.size(1)], device=logits.device | ||
) | ||
one_hot.scatter_(1, label[index], 1) | ||
logits[index] -= one_hot | ||
logits.div_(batch_size) | ||
return logits * loss_gradient.item(), None | ||
|
||
|
||
class DistCrossEntropy(torch.nn.Module): | ||
def __init__(self): | ||
super(DistCrossEntropy, self).__init__() | ||
|
||
def forward(self, logit_part, label_part): | ||
return DistCrossEntropyFunc.apply(logit_part, label_part) | ||
|
||
|
||
class AllGatherFunc(torch.autograd.Function): | ||
"""AllGather op with gradient backward""" | ||
|
||
@staticmethod | ||
def forward(ctx, tensor, *gather_list): | ||
gather_list = list(gather_list) | ||
distributed.all_gather(gather_list, tensor) | ||
return tuple(gather_list) | ||
|
||
@staticmethod | ||
def backward(ctx, *grads): | ||
grad_list = list(grads) | ||
rank = distributed.get_rank() | ||
grad_out = grad_list[rank] | ||
|
||
dist_ops = [ | ||
distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True) | ||
if i == rank | ||
else distributed.reduce( | ||
grad_list[i], i, distributed.ReduceOp.SUM, async_op=True | ||
) | ||
for i in range(distributed.get_world_size()) | ||
] | ||
for _op in dist_ops: | ||
_op.wait() | ||
|
||
grad_out *= len(grad_list) # cooperate with distributed loss function | ||
return (grad_out, *[None for _ in range(len(grad_list))]) | ||
|
||
|
||
AllGather = AllGatherFunc.apply |
Oops, something went wrong.