faceID Part 2
oguzhanbsolak committed Dec 6, 2023
1 parent d9a07bb commit 5085a53
Showing 5 changed files with 31 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/linters/.python-lint
@@ -7,7 +7,7 @@ ignored-classes = ModelProto
 max-line-length = 99
 [DESIGN]
 max-locals=100
-max-statements=300
+max-statements=350
 min-public-methods=1
 max-branches=120
 max-module-lines=5000
2 changes: 1 addition & 1 deletion .pylintrc
@@ -7,7 +7,7 @@ ignored-classes = ModelProto
 max-line-length = 99
 [DESIGN]
 max-locals=100
-max-statements=300
+max-statements=350
 min-public-methods=1
 max-branches=120
 max-module-lines=5000
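Both lint configurations are kept in sync, and the max-statements ceiling rises from 300 to 350, presumably to accommodate the longer functions in train.py below. A quick sketch for checking a file against the shared rcfile programmatically, assuming a recent pylint is installed and the command runs from the repository root:

```python
# Run pylint with the repository's config; the report prints to stdout.
from pylint.lint import Run

Run(["--rcfile=.pylintrc", "train.py"], exit=False)
```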
12 changes: 6 additions & 6 deletions models/model_irse_DRL.py
@@ -30,17 +30,17 @@
 # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
 #
 ###################################################################################################
 
 """
 FaceID Teacher Model to be used for Knowledge Distillation
 """
 from collections import namedtuple
 
 import torch
-import torch.nn as nn
+from torch import nn
 import torch.nn.functional as F
 import torchvision.transforms.functional as FT
 
 
 # Supports: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152']
 
 class DRL(nn.Module):
     """
     Dimensionality reduction layers
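The DRL layer definitions are collapsed in this view. As orientation only, here is a hypothetical sketch of what such a dimensionality-reduction head usually looks like, projecting the backbone's 512-dimensional embedding down to the --dr size; the hidden lines may differ. The sketch reuses the file's own nn and F imports:

```python
# Hypothetical sketch; the real DRL body is collapsed in this diff.
class DRLSketch(nn.Module):
    """Project a 512-d backbone embedding down to `dr` dimensions."""
    def __init__(self, dr=64, in_dim=512):
        super().__init__()
        self.fc = nn.Linear(in_dim, dr)

    def forward(self, x):
        """Reduce, then L2-normalize onto the unit hypersphere."""
        return F.normalize(self.fc(x), p=2, dim=1)
```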
@@ -85,7 +85,7 @@ def __init__(self, resnet, DRL):
     def forward(self, x):
         """Forward prop"""
         if x.shape[1] == 6:
-            if (not self.Teacher_mode):
+            if not self.Teacher_mode:
                 self.Teacher_mode = True
             x = x[:, 3:, :, :]
             x_flip = FT.hflip(x)
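The 6-channel branch implies the teacher receives two stacked RGB tensors and keeps only the last three channels before computing a horizontally flipped copy; flip-augmented embedding extraction is a common face-recognition technique. A self-contained sketch of that pattern, where the fusion by summation is an assumption (the diff is truncated after the flip):

```python
# Sketch of flip-augmented embedding extraction; the fusion step is assumed.
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as FT
from torch import nn

def embed_with_flip(backbone: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Embed x and its horizontal flip, then fuse and L2-normalize."""
    if x.shape[1] == 6:        # stacked pair: keep the last RGB triplet
        x = x[:, 3:, :, :]
    emb = backbone(x) + backbone(FT.hflip(x))
    return F.normalize(emb, p=2, dim=1)
```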
@@ -101,7 +101,7 @@ class Flatten(nn.Module):
"""Flattens the input"""
def forward(self, input):
"""Forward prop"""
return input.view(input.size(0), -1)
return input.view(input.size(0), -1) # pylint: disable=redefined-builtin


def l2_norm(input, axis=1):
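The l2_norm helper is truncated here. In the IR-SE model family it is conventionally a per-row division by the L2 norm; the following is that conventional form, offered as an assumption rather than a copy of the hidden lines:

```python
import torch

def l2_norm(input, axis=1):  # pylint: disable=redefined-builtin
    """Normalize `input` to unit L2 norm along `axis`."""
    norm = torch.norm(input, 2, axis, True)
    return torch.div(input, norm)
```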
13 changes: 9 additions & 4 deletions parsecmd.py
@@ -68,13 +68,16 @@ def get_parser(model_names, dataset_names):
                         help='when simulating, use "round()" in AvgPool operations '
                              '(default: use "floor()")')
     parser.add_argument('--dr', type=int, default=None,
-                        help='Embedding dimensionality for dimensionality reduction (default: None)')
+                        help='Embedding dimensionality for dimensionality '
+                             'reduction (default: None)')
     parser.add_argument('--scaf-margin', default=28.6,
-                        type=float, help='Margin hyperparameter for Sub-center ArcFace Loss')
+                        type=float, help='Margin hyperparameter '
+                                         'for Sub-center ArcFace Loss')
     parser.add_argument('--scaf-scale', default=64,
                         type=int, help='Scale hyperparameter for Sub-center ArcFace Loss')
     parser.add_argument('--backbone-checkpoint', type=str, default=None, metavar='PATH',
-                        help='path to checkpoint from which to load backbone weights (default: None)')
+                        help='path to checkpoint from which to load '
+                             'backbone weights (default: None)')
     parser.add_argument('--copy-output-folder', type=str, default=None, metavar='PATH',
                         help='Path to copy output folder (default: None)')
     parser.add_argument('--kd-relationbased', action='store_true', default=False,
@@ -104,7 +107,9 @@ def get_parser(model_names, dataset_names):
     optimizer_args.add_argument('--lr', '--learning-rate',
                                 type=float, metavar='LR', help='initial learning rate')
     optimizer_args.add_argument('--scaf-lr', default=1e-4,
-                                type=float, metavar='SCAF_LR', help='initial learning rate for Sub-center ArcFace Loss optimizer')
+                                type=float, metavar='SCAF_LR',
+                                help='initial learning rate for Sub-center '
+                                     'ArcFace Loss optimizer')
     optimizer_args.add_argument('--momentum', default=0.9, type=float,
                                 metavar='M', help='momentum')
     optimizer_args.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
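One detail about the wrapped help strings above: Python joins adjacent string literals with no separator, so the explicit trailing spaces are load-bearing. A minimal illustration:

```python
# Adjacent string literals concatenate verbatim; spaces must be explicit.
broken = ('Margin hyperparameter'
          'for Sub-center ArcFace Loss')   # words fuse: 'hyperparameterfor'
fixed = ('Margin hyperparameter '
         'for Sub-center ArcFace Loss')
assert fixed == 'Margin hyperparameter for Sub-center ArcFace Loss'
```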
23 changes: 14 additions & 9 deletions train.py
@@ -396,7 +396,9 @@ def main():

     elif args.dr:
 
-        criterion = pml_losses.SubCenterArcFaceLoss(num_classes=args.num_classes, embedding_size=args.dr, margin=args.scaf_margin, scale=args.scaf_scale)
+        criterion = pml_losses.SubCenterArcFaceLoss(
+            num_classes=args.num_classes, embedding_size=args.dr,
+            margin=args.scaf_margin, scale=args.scaf_scale)
         if args.resumed_checkpoint_path:
             checkpoint = torch.load(args.resumed_checkpoint_path,
                                     map_location=lambda storage, loc: storage)
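SubCenterArcFaceLoss, from pytorch-metric-learning (imported as pml_losses above), carries learnable class-center weights of its own, which is why the training loop keeps a separate loss_optimizer. A standalone sketch with illustrative sizes:

```python
# Standalone sketch; class count, embedding size, and batch are illustrative.
import torch
from pytorch_metric_learning import losses

criterion = losses.SubCenterArcFaceLoss(num_classes=100, embedding_size=64,
                                        margin=28.6, scale=64)
loss_optimizer = torch.optim.Adam(criterion.parameters(), lr=1e-4)

embeddings = torch.randn(32, 64)          # stand-in for model(batch)
labels = torch.randint(0, 100, (32,))
loss = criterion(embeddings, labels)      # ArcFace margin applied per sub-center
loss.backward()
loss_optimizer.step()
```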
@@ -409,11 +411,9 @@

         distance_fn = CosineSimilarity()
         custom_knn = CustomKNN(distance_fn, batch_size=args.batch_size)
-        accuracy_calculator = AccuracyCalculator(knn_func=custom_knn, include=("precision_at_1",), k=1)
+        accuracy_calculator = AccuracyCalculator(knn_func=custom_knn,
+                                                 include=("precision_at_1",), k=1)
 
-        if args.validation_split != 0:
-            msglogger.info("WARNING: DR works with whole training set, overwriting validation split to 0")
-            args.validation_split = 0
 
     else:
         if not args.regression:
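CustomKNN performs exact nearest-neighbor search in batches, so the full similarity matrix never has to be materialized at once. A standalone sketch of the evaluation wiring configured above, mirroring the five-argument get_accuracy call used in scaf_test further down (the signature changed in later pytorch-metric-learning releases):

```python
# Sketch of precision@1 evaluation with a batched cosine-similarity k-NN.
import torch
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import CustomKNN

knn = CustomKNN(CosineSimilarity(), batch_size=256)   # batch size illustrative
calc = AccuracyCalculator(knn_func=knn, include=("precision_at_1",), k=1)

emb = torch.randn(100, 64)                 # stand-in embeddings
labels = torch.randint(0, 10, (100,))
# True -> query and reference are the same set, so self-matches are excluded
acc = calc.get_accuracy(emb, emb, labels, labels, True)
print(acc["precision_at_1"])
```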
@@ -464,7 +464,8 @@ def main():
         compression_scheduler = distiller.file_config(model, optimizer, args.compress,
                                                       compression_scheduler,
                                                       (start_epoch-1)
-                                                      if args.resumed_checkpoint_path else None, loss_optimizer)
+                                                      if args.resumed_checkpoint_path
+                                                      else None, loss_optimizer)
     elif compression_scheduler is None:
         compression_scheduler = distiller.CompressionScheduler(model)
 
@@ -494,7 +495,8 @@ def main():
         dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt,
                                                 args.kd_teacher_wt)
         if args.kd_relationbased:
-            args.kd_policy = kd_relationbased.RelationBasedKDPolicy(model, teacher, dlw, args.act_mode_8bit)
+            args.kd_policy = kd_relationbased.RelationBasedKDPolicy(model, teacher,
+                                                                    dlw, args.act_mode_8bit)
         else:
             args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher,
                                                                    args.kd_temp, dlw)
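Both branches wrap the student and teacher in a Distiller knowledge-distillation policy; kd_relationbased is this repository's relation-based variant, while the else branch uses stock response-based KD. A minimal sketch of the stock wiring, assuming the vendored distiller package is importable, with illustrative weights and temperature:

```python
# Sketch of Distiller's response-based KD setup; values are illustrative.
import distiller
from torch import nn

student = nn.Linear(8, 2)    # stand-ins for the real student/teacher models
teacher = nn.Linear(8, 2)

# (distillation, student, teacher) loss weights
dlw = distiller.DistillationLossWeights(0.5, 0.5, 0.0)
kd_policy = distiller.KnowledgeDistillationPolicy(student, teacher,
                                                  temperature=4.0,
                                                  loss_weights=dlw)
```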
@@ -997,16 +999,18 @@ def update_bn_stats(train_loader, model, args):
         _ = model(inputs)
 
 def get_all_embeddings(dataset, model):
+    """Get all embeddings from the test set"""
     tester = testers.BaseTester()
     return tester.get_all_embeddings(dataset, model)
 
 def scaf_test(val_loader, model, accuracy_calculator):
+    """Perform test for SCAF"""
     test_embeddings, test_labels = get_all_embeddings(val_loader.dataset, model)
     test_labels = test_labels.squeeze(1)
     accuracies = accuracy_calculator.get_accuracy(
         test_embeddings, test_embeddings, test_labels, test_labels, True
     )
-    msglogger.info("Test set accuracy (Precision@1) = {}".format(accuracies["precision_at_1"]))
+    msglogger.info(f"Test set accuracy (Precision@1) = {accuracies['precision_at_1']}")
     return accuracies["precision_at_1"], 0, 0, 0
 
 def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=None):
@@ -1255,7 +1259,8 @@ def save_tensor(t, f, regression=True):
             target /= 128.
 
         if args.generate_sample is not None and args.act_mode_8bit and not sample_saved:
-            sample.generate(args.generate_sample, inputs, target, output, args.dataset, False, args.slice_sample)
+            sample.generate(args.generate_sample, inputs, target, output,
+                            args.dataset, False, args.slice_sample)
             sample_saved = True
 
         if args.csv_prefix is not None:
