# evaluate_lmmd.py
# from Model_0 import resnet101
from Model_7 import resnet101
# from hierarchy_cls_train import model_save_path,train_loader,valid_loader,DEVICE,NUM_CLASSES, GRAYSCALE
import torch
import torch.nn as nn
import os
from util import (compute_accuracy_model0, calculate_num_class, hierarchy_dict,
                  calculate_num_class_model0, compute_accuracy_model12,
                  compute_accuracy_model7_track_based, track_based_accuracy,
                  track_based_accuracy_majority_vote, Otsu_Threshold,
                  compute_accuracy_model7_track_based_level_2_only,
                  track_based_accuracy_level2_only)
from IPython import embed
from torchvision import transforms
from fish_rail_dataloader_track_based import Fish_Rail_Dataset, Fish_Rail_Tracking_Result
from torch.utils.data import DataLoader
import timm
from prefetch_generator import BackgroundGenerator
from loss_funcs.classifier import Classifier
from loss_funcs.lmmd import LMMDLoss
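# Classifier and LMMDLoss come from the LMMD (subdomain-adaptation) loss package
# used at training time. LMMDLoss itself is an adaptation loss and is not
# referenced in this evaluation script; only the trained Classifier head is used.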
class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())
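# DataLoaderX wraps the standard DataLoader iterator in prefetch_generator's
# BackgroundGenerator, which pre-loads upcoming batches in a background thread
# so the GPU does not sit idle waiting for data.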
GRAYSCALE = False
# NUM_CLASSES = calculate_num_class(hierarchy_dict) #37
# NUM_CLASSES = calculate_num_class_model0(hierarchy_dict) # model0 31
NUM_level_1_CLASSES, NUM_level_2_CLASSES = calculate_num_class(hierarchy_dict)
model_name = 'resnext50_32x4d'
model_save_path = './checkpoints/' + model_name + '_lmmd'
DEVICE = 'cuda:0'
BATCH_SIZE = 256 * 3
img_size = 224
# model-7
save_path_val = './per img predictions val/'+model_name+'_lmmd'
save_path_tr = './per img predictions train/'+model_name+'_lmmd'
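# Directories for per-image prediction dumps; save_path_val is passed both to the
# image-level evaluation and to track_based_accuracy_level2_only below, which
# aggregates those per-image outputs into track-level accuracy.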
custom_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                       transforms.ToTensor()])
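# All splits share the same preprocessing: resize to img_size x img_size and
# ToTensor (which scales pixel values to [0, 1]); no further normalization is
# applied in this script.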
valid_gt_path = './rail_cropped_data/labels_track_based/fish-rail-valid-plus_sleeper_shark_nonfish-level2_only.csv'
train_gt_path = 'rail_cropped_data/labels_track_based/fish-rail-train-plus_sleeper_shark_nonfish-level2_only.csv'
img_dir = './rail_cropped_data/cropped_box_with_sleeper_shark_non_fish'
train_dataset = Fish_Rail_Dataset(csv_path=train_gt_path,
                                  img_dir=img_dir,
                                  transform=custom_transform,
                                  hierarchy=hierarchy_dict)
valid_dataset = Fish_Rail_Dataset(csv_path=valid_gt_path,
                                  img_dir=img_dir,
                                  transform=custom_transform,
                                  hierarchy=hierarchy_dict)
train_loader = DataLoaderX(dataset=train_dataset,
                           batch_size=BATCH_SIZE,
                           shuffle=False,
                           num_workers=0)
valid_loader = DataLoaderX(dataset=valid_dataset,
                           batch_size=BATCH_SIZE,
                           shuffle=False,
                           num_workers=0)
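# shuffle=False keeps images in CSV order, presumably so the per-image predictions
# saved during evaluation can be grouped back into tracks; num_workers=0 loads data
# in the main process (prefetching is handled by DataLoaderX instead).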
uda_img_dir = 'test_sleeper_shark/AK-50308-220423_214636-C1H-025-220524_210051_809_1-20230105T014047Z-001/AK-50308-220423_214636-C1H-025-220524_210051_809_1'
tracking_result = 'test_sleeper_shark/AK-50308-220423_214636-C1H-025-220524_210051_809_1-20230105T014047Z-001-result/AK-50308-220423_214636-C1H-025-220524_210051_809_1/tracking_result_with_huber.csv'
target_dataset = Fish_Rail_Tracking_Result(csv_path=tracking_result,
                                           img_dir=uda_img_dir,
                                           transform=custom_transform,
                                           crop=True)
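# Target-domain (UDA) data: cropped detections from the sleeper-shark video, built
# from the tracker's result CSV. It is constructed here but not used in the
# evaluation below, which only scores the labeled validation split.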
### load model
# best_epoch = 23  # nonfish + sleeper shark
best_epoch = 11  # nonfish + sleeper shark
best_epoch = 6   # nonfish + sleeper shark (this later assignment is the one that takes effect)
stop_at_level_1_threshold = 0.85
model = timm.create_model(model_name, pretrained=True, num_classes=0)
PATH = os.path.join(model_save_path,'parameters_epoch_'+str(best_epoch)+'.pkl')
model.load_state_dict(torch.load(PATH))
model.to(DEVICE)
clf = Classifier(num_class=NUM_level_2_CLASSES, feature_dim=2048)
PATH = os.path.join(model_save_path,'clf_parameters_epoch_'+str(best_epoch)+'.pkl')
clf.load_state_dict(torch.load(PATH))
clf.to(DEVICE)
### final evaluation for model7
model.eval()
clf.eval()  # put the classifier head in eval mode as well (a no-op if it has no dropout/BN layers)
# for model7
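# A minimal sanity check (not part of the original script): with num_classes=0 the
# timm backbone returns pooled features, which for resnext50_32x4d are 2048-dim and
# must match Classifier(feature_dim=2048). Safe to run here since the model is
# already in eval mode and gradients are disabled.
with torch.no_grad():
    _probe = model(torch.zeros(1, 3, img_size, img_size, device=DEVICE))
    assert _probe.shape[-1] == 2048, _probe.shape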
with torch.set_grad_enabled(False):  # save memory during inference
    avg_level_2_acc_p1p2_31_val, acc_2_p1p2_31_val = compute_accuracy_model7_track_based_level_2_only(
        [model, clf], valid_loader, best_epoch, DEVICE, save_path_val, lmmd=True)
    ## compute track-based accuracy from the recorded per-image confidences
    avg_level_2_acc_p1p2_31_val_track, acc_2_p1p2_31_val_track = \
        track_based_accuracy_level2_only(save_path_val, best_epoch)
    print(
        'Track-based Epoch: %03d | Valid: Level-2 Avg p1p2 max out of 31: %.3f%%' % (
            best_epoch,
            avg_level_2_acc_p1p2_31_val_track * 100,
        ))
    print('Track-based Individual accuracy: Valid: '
          'Level-2 p1p2 max out of 31:', acc_2_p1p2_31_val_track)
    print('Image-based Epoch: %03d | Valid: Level-2 Avg p1p2 max out of 31: %.3f%%' % (
        best_epoch,
        avg_level_2_acc_p1p2_31_val * 100
    ))
    print('Image-based Individual accuracy: Valid: '
          'Level-2 p1p2 max out of 31:', acc_2_p1p2_31_val)
embed()  # drop into an interactive IPython shell to inspect the results