Skip to content

Commit

Permalink
FaceID part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
oguzhanbsolak committed Dec 7, 2023
1 parent 5085a53 commit 479207b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 7 additions & 9 deletions models/model_irse_DRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@ def forward(self, x): # pylint: disable=arguments-differ
return x



class Ensemble(nn.Module):
"""
Ensemble of Teacher and DRL
"""
def __init__(self, resnet, DRL):
def __init__(self, resnet, drl):
super().__init__()
self.resnet = resnet
self.DRL = DRL
self.DRL = drl
self.Teacher_mode = False

def forward(self, x):
Expand All @@ -99,16 +98,15 @@ def forward(self, x):

class Flatten(nn.Module):
"""Flattens the input"""
def forward(self, input):
def forward(self, x):
"""Forward prop"""
return input.view(input.size(0), -1) # pylint: disable=redefined-builtin
return x.view(x.size(0), -1)


def l2_norm(input, axis=1):
def l2_norm(x, axis=1):
"""l2 norm"""
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)

norm = torch.norm(x, 2, axis, True)
output = torch.div(x, norm)
return output


Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def scaf_test(val_loader, model, accuracy_calculator):
accuracies = accuracy_calculator.get_accuracy(
test_embeddings, test_embeddings, test_labels, test_labels, True
)
msglogger.info(f"Test set accuracy (Precision@1) = {accuracies['precision_at_1']}")
msglogger.info('Test set accuracy (Precision@1) = %f', accuracies['precision_at_1'])
return accuracies["precision_at_1"], 0 , 0 , 0

def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=None):
Expand Down

0 comments on commit 479207b

Please sign in to comment.