"""Script for calculating CLIP score."""
import csv
import click
import tqdm
import torch
from torch_utils import distributed as dist
from training import dataset
import open_clip
from torchvision import transforms
from torch_utils.download_util import check_file_by_key
#----------------------------------------------------------------------------
@click.group()
def main():
"""Calculate CLIP score.
python clip-score.py calc --images=path/to/images
torchrun --standalone --nproc_per_node=1 clip-score.py calc --images=path/to/images
"""
#----------------------------------------------------------------------------
@main.command()
@click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
@click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), show_default=True)
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
@click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=250, show_default=True)
@click.option('--desc', help='A description string', metavar='STR', type=str)
@torch.no_grad()
def calc(image_path, max_batch_size, desc=None, num_expected=None, seed=0,
         num_workers=3, prefetch_factor=2, device=torch.device('cuda')):
"""Calculate FID for a given set of images."""
torch.multiprocessing.set_start_method('spawn')
dist.init()
# Rank 0 goes first.
if dist.get_rank() != 0:
torch.distributed.barrier()
# List images.
dist.print0(f'Loading images from "{image_path}"...')
dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    # Load the MS-COCO validation captions.
    prompt_path, _ = check_file_by_key('prompts')
    dist.print0("Loading MS-COCO 30k captions...")
    sample_captions = []
    with open(prompt_path, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            text = row['text']
            sample_captions.append(text)
    # Loading CLIP model.
    dist.print0('Loading CLIP-ViT-g-14 model...')
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s34b_b88k')
    tokenizer = open_clip.get_tokenizer('ViT-g-14')
    model.to(device)

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()
    # Divide images into batches.
    num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
    data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)
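    # Note: images are paired with captions by dataset index, i.e. image i is scored
    # against sample_captions[i]. This assumes the images were generated in the same
    # order as the captions in the prompts CSV.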

    # Accumulate statistics.
    dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
    avg_clip_score, batch_idx = 0, 0
    to_pil = transforms.ToPILImage()
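    # CLIP score per image: embed the image and its caption with CLIP, L2-normalize both
    # embeddings, and take 100 * their cosine similarity (the dot product of the unit
    # vectors). The final score is this value averaged over all image/caption pairs.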
    for images, _ in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
        torch.distributed.barrier()
        prompts = sample_captions[rank_batches[batch_idx][0] : rank_batches[batch_idx][-1] + 1]
        images = torch.stack([preprocess(to_pil(img)) for img in images], dim=0).to(device)
        text = tokenizer(prompts).to(device)
        image_features = model.encode_image(images)
        text_features = model.encode_text(text)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        sd_clip_score = 100 * (image_features * text_features).sum(dim=-1)
        avg_clip_score += sd_clip_score.sum()
        batch_idx += 1
        # if batch_idx % 10 == 0:
        #     total_samples = batch_idx * max_batch_size
        #     dist.print0(f"CLIP score under {total_samples} samples: {avg_clip_score / total_samples}")

    # Sum the per-rank totals before averaging over the full dataset.
    if dist.get_world_size() > 1:
        torch.distributed.all_reduce(avg_clip_score)
    avg_clip_score /= len(dataset_obj)
    dist.print0(f"CLIP score: {avg_clip_score}")

    # Rank 0 appends the result to clip_score.txt.
    if dist.get_rank() == 0:
        with open('clip_score.txt', mode='a') as note:
            if desc is not None:
                note.write(f'{desc} {avg_clip_score}\n')
            else:
                note.write('{} {} {}\n'.format(image_path.split('/')[-2], image_path.split('/')[-1], avg_clip_score))
    torch.distributed.barrier()
#----------------------------------------------------------------------------
if __name__ == "__main__":
    main()
#----------------------------------------------------------------------------