-
Notifications
You must be signed in to change notification settings - Fork 1
/
construct_vecbase.py
57 lines (45 loc) · 1.88 KB
/
construct_vecbase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
class RetrivalDataset(Dataset):
    """Dataset over the image files directly inside one folder.

    Each item is ``(image_tensor, img_path)``. Files are indexed in sorted
    filename order, so iterating the dataset with ``shuffle=False`` gives a
    deterministic order. (The class name's spelling is kept unchanged because
    other code may import it.)
    """

    def __init__(self, folder_path, transform=None, extensions=('.png',)):
        """Index the matching files in ``folder_path`` (non-recursive).

        Args:
            folder_path: directory to scan with ``os.listdir``.
            transform: callable applied to each RGB PIL image. When ``None``,
                defaults to resize to 224x224, ``ToTensor``, and normalization
                with the ImageNet mean/std used by torchvision's pretrained models.
            extensions: a filename suffix or tuple of suffixes to include,
                matched case-sensitively with ``str.endswith``.
        """
        if isinstance(extensions, str):
            extensions = (extensions,)
        else:
            extensions = tuple(extensions)

        self.file_list = []
        for name in sorted(os.listdir(folder_path)):
            path = os.path.join(folder_path, name)
            # Skip sub-directories whose names happen to match a suffix;
            # they would fail later when opened as images.
            if name.endswith(extensions) and os.path.isfile(path):
                self.file_list.append(path)

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return ``(transformed_image, path)``."""
        img_path = self.file_list[index]
        # A context manager releases the file handle deterministically, which
        # matters when several DataLoader workers open many files. convert()
        # returns a fully loaded copy, so it stays valid after the file closes.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        return self.transform(image), img_path
def _load_vgg16_backbone():
    """Return the convolutional trunk (``.features``) of an ImageNet-pretrained VGG16.

    Uses the ``weights=`` API when torchvision provides ``VGG16_Weights``
    (``pretrained=True`` is deprecated there); otherwise falls back to
    ``pretrained=True``. Both select the IMAGENET1K_V1 checkpoint.
    """
    weights_enum = getattr(models, 'VGG16_Weights', None)
    if weights_enum is not None:
        vgg = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
    else:
        vgg = models.vgg16(pretrained=True)
    return vgg.features


def extract_features(dataset_path, batch_size=32, num_workers=4, device=None):
    """Compute one global-average-pooled VGG16 feature vector per image in a folder.

    Args:
        dataset_path: folder handed to ``RetrivalDataset``.
        batch_size: images per forward pass.
        num_workers: DataLoader worker processes used for loading images.
        device: torch device (or device string) to run on. Defaults to
            ``'cuda'`` when available, otherwise ``'cpu'``.

    Returns:
        ``(features, file_paths)``. ``features`` is an ``(N, C)`` NumPy array,
        where C is the output-channel count of VGG16's last conv layer, and
        ``file_paths[i]`` is the image that produced ``features[i]``. An empty
        folder yields a ``(0, C)`` float32 array and an empty list.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    backbone = _load_vgg16_backbone()
    backbone.eval()
    backbone.to(device)
    # Width of the pooled descriptor; needed to shape the result when the folder is empty.
    feat_dim = [m.out_channels for m in backbone if isinstance(m, torch.nn.Conv2d)][-1]

    dataset = RetrivalDataset(dataset_path)
    # shuffle=False keeps the feature rows aligned with the dataset's file order.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    features = []
    file_paths = []
    with torch.no_grad():
        for images, paths in loader:
            fmap = backbone(images.to(device))
            # Pool to (B, C, 1, 1), then flatten to (B, C). Unlike squeeze(),
            # flatten(1) keeps the batch axis for a final batch of one image.
            pooled = torch.nn.functional.adaptive_avg_pool2d(fmap, (1, 1)).flatten(1)
            features.append(pooled.cpu().numpy())
            file_paths.extend(paths)

    if not features:
        # np.vstack / np.concatenate reject an empty list, so build the empty result explicitly.
        return np.empty((0, feat_dim), dtype=np.float32), file_paths
    return np.concatenate(features, axis=0), file_paths
if __name__ == '__main__':
    # Local import: argparse is only needed when this file runs as a script.
    import argparse

    # Defaults reproduce the original hard-coded paths, so running with no
    # arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Extract VGG16 feature vectors for the images in a folder.')
    parser.add_argument('--dataset', default='/home/tiger/gh/dataset/DIV2K/DIV2K_train_HR',
                        help='folder containing the images to index')
    parser.add_argument('--feat-out', default='/home/tiger/gh/dataset/div_feat.npy',
                        help='output .npy file for the (N, C) feature matrix')
    parser.add_argument('--path-out', default='/home/tiger/gh/dataset/div_path.npy',
                        help='output .npy file for the N image paths (row-aligned with features)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='images per forward pass')
    args = parser.parse_args()

    features, file_paths = extract_features(args.dataset, batch_size=args.batch_size)
    np.save(args.feat_out, features)
    np.save(args.path_out, file_paths)