diff --git a/models/model_irse_DRL.py b/models/model_irse_DRL.py index 557b025eb..d29889045 100644 --- a/models/model_irse_DRL.py +++ b/models/model_irse_DRL.py @@ -71,15 +71,14 @@ def forward(self, x): # pylint: disable=arguments-differ return x - class Ensemble(nn.Module): """ Ensemble of Teacher and DRL """ - def __init__(self, resnet, DRL): + def __init__(self, resnet, drl): super().__init__() self.resnet = resnet - self.DRL = DRL + self.DRL = drl self.Teacher_mode = False def forward(self, x): @@ -99,16 +98,15 @@ def forward(self, x): class Flatten(nn.Module): """Flattens the input""" - def forward(self, input): + def forward(self, x): """Forward prop""" - return input.view(input.size(0), -1) # pylint: disable=redefined-builtin + return x.view(x.size(0), -1) -def l2_norm(input, axis=1): +def l2_norm(x, axis=1): """l2 norm""" - norm = torch.norm(input, 2, axis, True) - output = torch.div(input, norm) - + norm = torch.norm(x, 2, axis, True) + output = torch.div(x, norm) return output diff --git a/train.py b/train.py index 4440a3e7c..fc761abec 100644 --- a/train.py +++ b/train.py @@ -1010,7 +1010,7 @@ def scaf_test(val_loader, model, accuracy_calculator): accuracies = accuracy_calculator.get_accuracy( test_embeddings, test_embeddings, test_labels, test_labels, True ) - msglogger.info(f"Test set accuracy (Precision@1) = {accuracies['precision_at_1']}") + msglogger.info('Test set accuracy (Precision@1) = %f', accuracies['precision_at_1']) return accuracies["precision_at_1"], 0 , 0 , 0 def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=None):