Commit

Add files via upload

williamyang1991 authored Jul 19, 2023
1 parent 7cab67d commit f83f1b5
Showing 8 changed files with 288 additions and 0 deletions.
Empty file added criteria/__init__.py
44 changes: 44 additions & 0 deletions criteria/id_loss.py
@@ -0,0 +1,44 @@
import torch
from torch import nn
from configs.paths_config import model_paths
from models.encoders.model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # crop the central face region (assumes 256x256 input)
        x = self.face_pool(x)  # resize to the 112x112 input expected by ArcFace
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)  # identity features of the target images
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        id_logs = []
        count = 0
        for i in range(n_samples):
            # dot products of L2-normalized features, i.e. cosine similarities
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            id_logs.append({'diff_target': float(diff_target),
                            'diff_input': float(diff_input),
                            'diff_views': float(diff_views)})
            loss += 1 - diff_target
            id_diff = float(diff_target) - float(diff_views)
            sim_improvement += id_diff
            count += 1

        return loss / count, sim_improvement / count, id_logs
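
A minimal usage sketch (not part of the diff), assuming the ir_se50 weights referenced by model_paths exist locally and that inputs are 256x256 aligned face crops in [-1, 1] — the hard-coded crop above only makes sense at that resolution:

import torch

id_loss = IDLoss().cuda()

# hypothetical batch: source x, target y, generator output y_hat
x = torch.randn(4, 3, 256, 256, device='cuda')
y = torch.randn(4, 3, 256, 256, device='cuda')
y_hat = torch.randn(4, 3, 256, 256, device='cuda', requires_grad=True)

loss, sim_improvement, id_logs = id_loss(y_hat, y, x)
loss.backward()  # gradients reach y_hat through the frozen ArcFace features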
Empty file added criteria/lpips/__init__.py
35 changes: 35 additions & 0 deletions criteria/lpips/lpips.py
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: '0.1'.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'only v0.1 is supported for now'

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        # squared difference of unit-normalized activations per layer,
        # weighted by the learned 1x1 convolutions and spatially averaged
        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
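
A usage sketch under the same assumptions the module itself makes: a CUDA device is available (both subnetworks are moved to "cuda" unconditionally) and inputs are scaled to [-1, 1] to match the normalization buffers in BaseNet:

import torch

lpips = LPIPS(net_type='alex')  # downloads the v0.1 linear weights on first use

x = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
y = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
d = lpips(x, y)  # scalar: summed over layers, averaged over the batch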
96 changes: 96 additions & 0 deletions criteria/lpips/networks.py
@@ -0,0 +1,96 @@
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from criteria.lpips.utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)
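
For reference, a small sketch of the contract these classes share (shapes are illustrative, and torchvision weights are downloaded on first use): each network returns one unit-normalized activation map per target layer, with channel counts matching n_channels_list — the sizes LinLayers is built against:

import torch

net = get_network('vgg')  # frozen VGG16 feature extractor
feats = net(torch.rand(1, 3, 224, 224) * 2 - 1)
for f, nc in zip(feats, net.n_channels_list):
    assert f.shape[1] == nc  # one map per entry of target_layers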
30 changes: 30 additions & 0 deletions criteria/lpips/utils.py
@@ -0,0 +1,30 @@
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys, e.g. 'lin0.model.1.weight' -> '0.1.weight', to match LinLayers
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict
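
The renamed keys line up with LinLayers, whose only parameters are the 1x1 convolutions at index 1 of each Sequential (the Identity at index 0 has none). A quick check, assuming network access to the URL above:

sd = get_state_dict('alex')
print(list(sd.keys())[:2])  # e.g. ['0.1.weight', '1.1.weight']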
69 changes: 69 additions & 0 deletions criteria/moco_loss.py
@@ -0,0 +1,69 @@
import torch
from torch import nn
import torch.nn.functional as F
from configs.paths_config import model_paths


class MocoLoss(nn.Module):

    def __init__(self):
        super(MocoLoss, self).__init__()
        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
        self.model = self.__load_model()
        self.model.cuda()
        self.model.eval()

    @staticmethod
    def __load_model():
        import torchvision.models as models
        model = models.__dict__["resnet50"]()
        # freeze all layers but the last fc
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
        state_dict = checkpoint['state_dict']
        # rename moco pre-trained keys
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        # remove output layer
        model = nn.Sequential(*list(model.children())[:-1]).cuda()
        return model

    def extract_feats(self, x):
        x = F.interpolate(x, size=224)
        x_feats = self.model(x)
        x_feats = nn.functional.normalize(x_feats, dim=1)
        x_feats = x_feats.squeeze()
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        sim_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            sim_logs.append({'diff_target': float(diff_target),
                             'diff_input': float(diff_input),
                             'diff_views': float(diff_views)})
            loss += 1 - diff_target
            sim_diff = float(diff_target) - float(diff_views)
            sim_improvement += sim_diff
            count += 1

        return loss / count, sim_improvement / count, sim_logs
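
Usage mirrors IDLoss; a sketch assuming the checkpoint at model_paths['moco'] follows the official MoCo key layout ('module.encoder_q.*'), since the loader asserts only fc.weight and fc.bias are missing:

import torch

moco_loss = MocoLoss()

x = torch.randn(4, 3, 256, 256, device='cuda')
y = torch.randn(4, 3, 256, 256, device='cuda')
y_hat = torch.randn(4, 3, 256, 256, device='cuda', requires_grad=True)

# inputs are resized to 224x224 internally before feature extraction
loss, sim_improvement, sim_logs = moco_loss(y_hat, y, x)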
14 changes: 14 additions & 0 deletions criteria/w_norm.py
@@ -0,0 +1,14 @@
import torch
from torch import nn


class WNormLoss(nn.Module):

    def __init__(self, start_from_latent_avg=True):
        super(WNormLoss, self).__init__()
        self.start_from_latent_avg = start_from_latent_avg

    def forward(self, latent, latent_avg=None):
        if self.start_from_latent_avg:
            latent = latent - latent_avg
        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
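
A sketch with illustrative shapes — the 18x512 W+ layout is an assumption, not required by the module; latent_avg broadcasts over the batch and style dimensions:

import torch

w_norm_loss = WNormLoss(start_from_latent_avg=True)

latent = torch.randn(4, 18, 512)  # e.g. W+ codes from an encoder
latent_avg = torch.randn(512)     # mean latent of the generator
loss = w_norm_loss(latent, latent_avg)  # mean L2 norm of the offsets from latent_avg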
