-
Notifications
You must be signed in to change notification settings - Fork 1
/
validate18.py
54 lines (41 loc) · 1.83 KB
/
validate18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
This code is part of an adaptation/modification from the original project available at:
https://github.com/peterwang512/CNNDetection
The original code was created by Wang et al. and is used here under the terms of the license
specified in the original project's repository. Any use of this adapted/modified code
must respect the terms of such license.
Adaptations and modifications made by: Daniel Cabanas Gonzalez
Modification date: 08/04/2024
"""
import torch
import numpy as np
from networks.resnet import resnet18
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
from options.test_options import TestOptions
from data import create_dataloader
def _class_accuracy(y_true, y_pred, cls, threshold):
    """Return thresholded accuracy over the samples labelled ``cls``; ``nan`` if there are none."""
    mask = y_true == cls
    if not mask.any():
        # No samples of this class in the dataset: accuracy is undefined.
        return float('nan')
    return accuracy_score(y_true[mask], y_pred[mask] > threshold)


def validate(model, opt, threshold=0.5):
    """Evaluate a single-logit real/fake classifier on the data described by ``opt``.

    The dataloader is built with ``create_dataloader(opt)``. Every batch is run
    through ``model`` under ``torch.no_grad()``. The outputs are passed through a
    sigmoid and flattened, so one logit per image is assumed. The caller is
    responsible for putting the model in eval mode.

    Args:
        model: ``torch.nn.Module`` to evaluate. Input batches are moved to the
            device of its first parameter. A parameterless model falls back to
            CUDA.
        opt: options object forwarded unchanged to ``create_dataloader``.
        threshold: probability above which a prediction counts as label 1.
            Defaults to 0.5.

    Returns:
        Tuple ``(acc, ap, r_acc, f_acc, y_true, y_pred)``:
            acc: accuracy over all samples at ``threshold``.
            ap: average precision computed from the raw probabilities.
            r_acc: accuracy on label-0 (real) samples; ``nan`` if there are none.
            f_acc: accuracy on label-1 (fake) samples; ``nan`` if there are none.
            y_true: numpy array of the collected labels.
            y_pred: numpy array of the sigmoid probabilities.
    """
    # Send inputs to wherever the model lives instead of assuming CUDA, so CPU
    # evaluation works. Keep CUDA as the fallback for parameterless models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    data_loader = create_dataloader(opt)
    y_true, y_pred = [], []
    with torch.no_grad():
        for img, label in data_loader:
            in_tens = img.to(device)
            y_pred.extend(model(in_tens).sigmoid().flatten().tolist())
            y_true.extend(label.flatten().tolist())

    y_true, y_pred = np.array(y_true), np.array(y_pred)
    r_acc = _class_accuracy(y_true, y_pred, 0, threshold)
    f_acc = _class_accuracy(y_true, y_pred, 1, threshold)
    acc = accuracy_score(y_true, y_pred > threshold)
    ap = average_precision_score(y_true, y_pred)
    return acc, ap, r_acc, f_acc, y_true, y_pred
def main():
    """Load a ResNet-18 detector checkpoint and print its evaluation metrics.

    Options, including ``model_path``, are parsed by ``TestOptions``. The
    checkpoint is read onto the CPU and its ``'model'`` entry is loaded as the
    state dict. The model then runs on CUDA when available, otherwise on CPU.
    """
    opt = TestOptions().parse(print_options=False)

    # Single-logit head: one output per image, as expected by validate().
    model = resnet18(num_classes=1)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(opt.model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])

    # Fall back to CPU instead of crashing when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    acc, avg_precision, r_acc, f_acc, y_true, y_pred = validate(model, opt)
    print("accuracy:", acc)
    print("average precision:", avg_precision)
    print("accuracy of real images:", r_acc)
    print("accuracy of fake images:", f_acc)


if __name__ == '__main__':
    main()