embeddings.py
import os
import pickle
from typing import List, Union

import numpy as np
import pandas as pd
import PIL.Image  # imported explicitly so PIL.Image.Image is available for type checks
import torch
from datasets import Dataset, Image  # datasets' Image feature, not PIL.Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

# Load the pretrained CLIP model and its processor; everything runs on CPU here.
device = "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Collect all Flickr30k image paths and load the caption file.
image_dir = '/Users/shashankvats/projects_openai/FlickerSearchEngine/flicker30k/Images/'
image_path = [image_dir + path for path in os.listdir(image_dir) if path.endswith('.jpg')]
image_path.sort()
captions_df = pd.read_csv('captions.csv')
def encode_images(images: Union[List[str], List[PIL.Image.Image]], batch_size: int):
    """Compute CLIP image embeddings for a list of file paths or PIL images."""
    def transform_fn(el):
        # `el` is a batch: a dict whose 'image' value is a list of either PIL images
        # or (when paths were passed) undecoded {'path', 'bytes'} dicts.
        if isinstance(el['image'][0], PIL.Image.Image):
            imgs = el['image']
        else:
            imgs = [Image().decode_example(_) for _ in el['image']]
        return preprocess(images=imgs, return_tensors='pt')

    dataset = Dataset.from_dict({'image': images})
    # If paths were passed, store them as an undecoded Image column and decode lazily.
    dataset = dataset.cast_column('image', Image(decode=False)) if isinstance(images[0], str) else dataset
    dataset.set_format('torch')
    dataset.set_transform(transform_fn)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    image_embeddings = []
    pbar = tqdm(total=len(dataloader), position=0)
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            image_embeddings.extend(model.get_image_features(**batch).detach().cpu().numpy())
            pbar.update(1)
    pbar.close()
    return np.stack(image_embeddings)
def encode_text(text: List[str], batch_size: int):
    """Compute CLIP text embeddings for a list of strings."""
    dataset = Dataset.from_dict({'text': text})
    # Tokenize up front; CLIP's text encoder accepts at most 77 tokens per input.
    dataset = dataset.map(lambda el: preprocess(text=el['text'], return_tensors="pt",
                                                max_length=77, padding="max_length", truncation=True),
                          batched=True,
                          remove_columns=['text'])
    dataset.set_format('torch')
    dataloader = DataLoader(dataset, batch_size=batch_size)

    text_embeddings = []
    pbar = tqdm(total=len(dataloader), position=0)
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            text_embeddings.extend(model.get_text_features(**batch).detach().cpu().numpy())
            pbar.update(1)
    pbar.close()
    return np.stack(text_embeddings)
# Embed every Flickr30k image in batches of 32 and persist the embeddings to disk.
vector_embedding = encode_images(image_path, 32)
with open('flicker30k_image_embeddings.pkl', 'wb') as f:
    pickle.dump(vector_embedding, f)
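
# --- Illustrative usage sketch (an assumption, not part of the original pipeline) ---
# A minimal example of how the saved embeddings and encode_text above could be
# combined for text-to-image search via cosine similarity. The query string and
# top_k value below are hypothetical placeholders.
with open('flicker30k_image_embeddings.pkl', 'rb') as f:
    image_embeddings = pickle.load(f)

query_embedding = encode_text(['a dog running on the beach'], batch_size=1)  # hypothetical query

# Normalize both sides so the dot product equals cosine similarity.
image_norm = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
query_norm = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
scores = (image_norm @ query_norm.T).squeeze()

top_k = 5  # hypothetical choice
best = np.argsort(-scores)[:top_k]
print([image_path[i] for i in best])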