Commit f83f1b5 (1 parent: 7cab67d)
Showing 8 changed files with 288 additions and 0 deletions.
Empty file.
@@ -0,0 +1,44 @@
import torch
from torch import nn
from configs.paths_config import model_paths
from models.encoders.model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # crop the face region of interest
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)  # identity features of the target image
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        id_logs = []
        count = 0
        for i in range(n_samples):
            # dot products of identity embeddings (cosine similarity for the
            # l2-normalized ArcFace features)
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            id_logs.append({'diff_target': float(diff_target),
                            'diff_input': float(diff_input),
                            'diff_views': float(diff_views)})
            loss += 1 - diff_target
            id_diff = float(diff_target) - float(diff_views)
            sim_improvement += id_diff
            count += 1

        return loss / count, sim_improvement / count, id_logs
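For orientation, a minimal usage sketch of IDLoss follows; it is hypothetical, not part of this commit. It assumes the ir_se50 weights referenced by model_paths are available locally, a CUDA device, and 256x256 aligned face batches (the hard-coded crop x[:, :, 35:223, 32:220] presupposes inputs around that size).

import torch

# Hypothetical usage sketch, not part of this commit.
id_loss = IDLoss().cuda()

x = torch.rand(4, 3, 256, 256, device='cuda')                           # source images
y = x                                                                   # target images
y_hat = torch.rand(4, 3, 256, 256, device='cuda', requires_grad=True)   # generator output

loss, sim_improvement, id_logs = id_loss(y_hat, y, x)
loss.backward()  # gradients flow only through y_hat; y_feats is detached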
Empty file.
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network whose features are compared:
            'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: '0.1'.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'only v0.1 is supported for now'

        super(LPIPS, self).__init__()

        # pretrained feature network
        self.net = get_network(net_type).to("cuda")

        # learned linear layers on top of the feature differences
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
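The image range matters here: the shift and scale registered in BaseNet (next file) are the LPIPS v0.1 constants for inputs in [-1, 1]. A minimal, hypothetical sanity check, assuming a CUDA device since the module hard-codes .to("cuda"):

import torch

# Hypothetical sketch, not part of this commit.
lpips = LPIPS(net_type='alex')                           # fetches weights on first use
x = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1    # scale images to [-1, 1]
y = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
d = lpips(x, y)   # scalar: per-layer distances summed, then averaged over the batch
print(float(d))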
@@ -0,0 +1,96 @@
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from criteria.lpips.utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffers for input normalization (LPIPS v0.1 shift/scale)
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        # 1-based enumeration so the indices match target_layers
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features  # pretrained=True
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features  # pretrained=True
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features  # pretrained=True
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)
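A small, hypothetical check that the classes above fit together: a backbone returns one unit-normalized activation per entry in target_layers, with channel counts matching n_channels_list (the list LinLayers consumes):

import torch

# Hypothetical sketch, not part of this commit.
net = get_network('vgg')                       # downloads torchvision weights if needed
feats = net(torch.rand(1, 3, 64, 64) * 2 - 1)  # any size that survives the poolings
assert len(feats) == len(net.target_layers)
assert [f.shape[1] for f in feats] == net.n_channels_list  # [64, 128, 256, 512, 512]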
@@ -0,0 +1,30 @@
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict
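For intuition, normalize_activation rescales the channel vector at each spatial position to roughly unit L2 norm, which is what makes the squared differences in LPIPS.forward comparable across layers. A tiny hypothetical check:

import torch

# Hypothetical sketch, not part of this commit.
t = torch.randn(2, 8, 4, 4)
n = normalize_activation(t)
print(n.pow(2).sum(dim=1).sqrt())  # approximately 1.0 at every (batch, h, w) position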
@@ -0,0 +1,69 @@
import torch
from torch import nn
import torch.nn.functional as F
from configs.paths_config import model_paths


class MocoLoss(nn.Module):

    def __init__(self):
        super(MocoLoss, self).__init__()
        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
        self.model = self.__load_model()
        self.model.cuda()
        self.model.eval()

    @staticmethod
    def __load_model():
        import torchvision.models as models
        model = models.__dict__["resnet50"]()
        # freeze all layers but the last fc
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
        state_dict = checkpoint['state_dict']
        # rename moco pre-trained keys
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        # remove output layer
        model = nn.Sequential(*list(model.children())[:-1]).cuda()
        return model

    def extract_feats(self, x):
        x = F.interpolate(x, size=224)
        x_feats = self.model(x)
        x_feats = nn.functional.normalize(x_feats, dim=1)
        x_feats = x_feats.squeeze()
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        sim_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            sim_logs.append({'diff_target': float(diff_target),
                             'diff_input': float(diff_input),
                             'diff_views': float(diff_views)})
            loss += 1 - diff_target
            sim_diff = float(diff_target) - float(diff_views)
            sim_improvement += sim_diff
            count += 1

        return loss / count, sim_improvement / count, sim_logs
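As with IDLoss, a hypothetical usage sketch, assuming the MoCo checkpoint at model_paths['moco'] exists and a CUDA device is available. One caveat worth noting: extract_feats ends with squeeze(), so a batch of size 1 would lose its batch dimension and break the per-sample loop.

import torch

# Hypothetical sketch, not part of this commit.
moco_loss = MocoLoss()
x = torch.rand(4, 3, 256, 256, device='cuda')
y = x
y_hat = torch.rand(4, 3, 256, 256, device='cuda', requires_grad=True)
loss, sim_improvement, sim_logs = moco_loss(y_hat, y, x)
loss.backward()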
@@ -0,0 +1,14 @@
import torch
from torch import nn


class WNormLoss(nn.Module):

    def __init__(self, start_from_latent_avg=True):
        super(WNormLoss, self).__init__()
        self.start_from_latent_avg = start_from_latent_avg

    def forward(self, latent, latent_avg=None):
        if self.start_from_latent_avg:
            latent = latent - latent_avg
        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
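Finally, a hypothetical sketch of WNormLoss with StyleGAN2-sized latents (18 style vectors of 512 dimensions per sample; these shapes are an illustration, not part of this commit). latent_avg broadcasts over the batch dimension:

import torch

# Hypothetical sketch, not part of this commit.
w_norm_loss = WNormLoss(start_from_latent_avg=True)
latent = torch.randn(4, 18, 512)
latent_avg = torch.zeros(18, 512)        # broadcasts over the batch
loss = w_norm_loss(latent, latent_avg)   # mean per-sample L2 norm of (latent - latent_avg)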