-
Notifications
You must be signed in to change notification settings - Fork 8
/
main_eval.py
122 lines (100 loc) · 4.98 KB
/
main_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import json
import os
import torch
from PIL import Image
from tqdm import tqdm
import open_clip
# (architecture, pretrained-checkpoint) pairs in OpenCLIP naming.
# The whole list is evaluated when the script is run with ``--all``.
models = [
    # Checkpoints tagged 'openai'.
    ('RN50', 'openai'),
    ('RN101', 'openai'),
    ('RN50x4', 'openai'),
    ('ViT-B-32', 'openai'),
    ('RN50x16', 'openai'),
    ('RN50x64', 'openai'),
    ('ViT-L-14', 'openai'),

    # Checkpoints tagged 'datacomp_*' (currently disabled).
    # ('ViT-B-32-quickgelu', 'datacomp_s_s13m_b4k'),
    # ('ViT-B-32-quickgelu', 'datacomp_m_s128m_b4k'),
    # ('ViT-B-16', 'datacomp_l_s1b_b8k'),
    # ('ViT-L-14', 'datacomp_xl_s13b_b90k'),

    # Checkpoints tagged 'laion2b_*'.
    ('ViT-H-14', 'laion2b_s32b_b79k'),
    ('ViT-g-14', 'laion2b_s12b_b42k'),
    ('ViT-bigG-14', 'laion2b_s39b_b160k'),
    ('roberta-ViT-B-32', 'laion2b_s12b_b32k'),

    # Checkpoints tagged 'laion5b' (xlm-roberta text towers).
    ('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'),
    ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'),
]
def load_model(args, pretrained, device):
    """Build an OpenCLIP model in eval mode together with its tokenizer and transform.

    Args:
        args: Namespace providing ``model`` (architecture name) and
            ``model_cache_dir`` (passed through as ``cache_dir``).
        pretrained: Checkpoint name handed to OpenCLIP as ``pretrained``.
        device: Device the model is created on and moved to.

    Returns:
        A ``(model, tokenizer, transform)`` tuple. ``transform`` is the third
        element returned by ``open_clip.create_model_and_transforms``; the
        second element is discarded.
    """
    architecture = args.model

    clip_model, _unused_transform, image_transform = open_clip.create_model_and_transforms(
        model_name=architecture,
        pretrained=pretrained,
        cache_dir=args.model_cache_dir,
        device=device,
    )
    clip_model = clip_model.to(device)

    text_tokenizer = open_clip.get_tokenizer(architecture)

    # Switch to evaluation mode before handing the model to the caller.
    clip_model.eval()

    return clip_model, text_tokenizer, image_transform
@torch.no_grad()
def text_retrieval(pos_text, neg_text, image, model, tokenizer, transform, device):
    """Check whether the image scores higher against the positive caption.

    Both captions and the image are embedded with ``normalize=True`` and each
    caption embedding is multiplied with the transposed image embedding.

    Args:
        pos_text: Caption expected to match the image.
        neg_text: Distractor caption.
        image: Image object accepted by ``transform``.
        model: Object exposing ``encode_text`` and ``encode_image``.
        tokenizer: Callable turning a caption into a tensor of tokens.
        transform: Callable turning ``image`` into a tensor (a batch
            dimension is added here).
        device: Device the inputs are moved to.

    Returns:
        ``1`` if the positive score is strictly greater than the negative
        score, otherwise ``0`` (ties count as ``0``).
    """

    def _embed_caption(caption):
        # Tokenize one caption and encode it on the target device.
        tokens = tokenizer(caption).to(device)
        return model.encode_text(tokens, normalize=True)

    pos_features = _embed_caption(pos_text)
    neg_features = _embed_caption(neg_text)

    pixels = transform(image).unsqueeze(dim=0).to(device)
    image_features = model.encode_image(pixels, normalize=True)

    pos_similarity = (pos_features @ image_features.t()).item()
    neg_similarity = (neg_features @ image_features.t()).item()

    return int(pos_similarity > neg_similarity)
def evaluate(image_root, dataset, model, tokenizer, transform, device):
    """Compute the per-category accuracy of positive-vs-negative caption retrieval.

    Args:
        image_root: Directory onto which each sample's ``filename`` is joined.
        dataset: Mapping of category name to a mapping of samples. Each sample
            is a dict with the keys ``filename``, ``caption`` and
            ``negative_caption``.
        model: Passed through to ``text_retrieval``.
        tokenizer: Passed through to ``text_retrieval``.
        transform: Passed through to ``text_retrieval``.
        device: Passed through to ``text_retrieval``.

    Returns:
        Dict mapping each category name to the fraction of its samples for
        which ``text_retrieval`` returned ``1``. A category without samples is
        reported as ``0.0``.
    """
    metrics = {}
    for category, samples in dataset.items():
        if not samples:
            # An empty category would otherwise raise ZeroDivisionError and
            # throw away the metrics already computed for other categories.
            metrics[category] = 0.0
            continue

        correct_cnt = 0
        # Only the sample dicts are needed; their keys are not used.
        for sample in tqdm(samples.values(), desc=f'evaluating {category}'):
            image_path = os.path.join(image_root, sample['filename'])
            # Close the image file deterministically instead of leaking one
            # handle per sample until garbage collection.
            with Image.open(image_path) as image:
                correct_cnt += text_retrieval(
                    sample['caption'],
                    sample['negative_caption'],
                    image,
                    model,
                    tokenizer,
                    transform,
                    device,
                )

        metrics[category] = correct_cnt / len(samples)
    return metrics
if __name__ == '__main__':
    # Command-line entry point: load the seven benchmark JSON files, evaluate
    # either one model or every entry of ``models``, print the metrics and,
    # when --output is given, write one JSON result file per model.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="RN50", help="Model architecture to use from OpenCLIP")
    parser.add_argument('--pretrained', type=str, default="openai", help="Model checkpoint name to use from OpenCLIP")
    parser.add_argument('--model_cache_dir', default=None, type=str, help="Directory to where downloaded models are cached")
    parser.add_argument('--output', type=str, default=None, help="Directory to where results are saved")
    # The image root is joined with every sample filename, so a missing value
    # can never work; reject it up front instead of failing inside os.path.join.
    parser.add_argument('--coco_image_root', type=str, required=True, help="Directory containing the images referenced by the data files")
    parser.add_argument('--data_root', type=str, default='./data')
    parser.add_argument('--all', action="store_true", default=False, help="Whether to test all the pretrained models in the paper")
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_dict = {
        'add_obj': f'{args.data_root}/add_obj.json',
        'add_att': f'{args.data_root}/add_att.json',
        'replace_obj': f'{args.data_root}/replace_obj.json',
        'replace_att': f'{args.data_root}/replace_att.json',
        'replace_rel': f'{args.data_root}/replace_rel.json',
        'swap_obj': f'{args.data_root}/swap_obj.json',
        'swap_att': f'{args.data_root}/swap_att.json',
    }
    dataset = {}
    for c, data_path in data_dict.items():
        # Use a context manager so each data file is closed right after parsing.
        with open(data_path, 'r', encoding='utf-8') as data_file:
            dataset[c] = json.load(data_file)

    # --output is optional: without it the metrics are only printed.
    # (os.makedirs(None) would raise TypeError.)
    if args.output is not None:
        os.makedirs(args.output, exist_ok=True)

    if args.all:
        print("Evaluating all models")
        targets = list(models)
    else:
        targets = [(args.model, args.pretrained)]

    for modelname, pretrained in targets:
        print(f"Evaluating {modelname}-{pretrained}")
        # load_model reads the architecture name from args.model.
        args.model = modelname
        model, tokenizer, transform = load_model(args, pretrained, device)
        metrics = evaluate(args.coco_image_root, dataset, model, tokenizer, transform, device)
        print(metrics)

        if args.output is not None:
            result_path = os.path.join(args.output, f'{args.model}-{pretrained}.json')
            print(f"Dump results to: {result_path}")
            # Context manager guarantees the results are flushed and closed.
            with open(result_path, 'w', encoding='utf-8') as result_file:
                json.dump(metrics, result_file, indent=4)

        # Drop this model before the next one is loaded so two models are
        # never held in memory at the same time.
        del model, tokenizer, transform
        if device == 'cuda':
            torch.cuda.empty_cache()