forked from mrktracy/masked_rpm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
169 lines (128 loc) · 6.04 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import numpy as np
import torch
from torch.utils.data import Dataset
import random
from collections import defaultdict
class CustomMNIST(Dataset):
    """Synthetic "missing digit" task assembled from labelled digit images.

    For every sample a hidden digit ``n`` in 1..8 is drawn once, at
    construction time. Fetching the sample stacks 8 images of digit ``n - 1``
    and 8 images of digit ``n + 1`` into one tensor. Which group comes first
    is also fixed at construction time. The target is ``n - 1``, in 0..7.

    The concrete images are re-drawn with ``random.sample`` on every access,
    so repeated reads of the same index can return different images.

    Args:
        mnist_data: iterable of ``(image, label)`` pairs. Images must be
            tensors that ``torch.stack`` can combine. Every label used must
            have at least 8 images.
        num_samples: number of items the dataset reports.
    """

    def __init__(self, mnist_data, num_samples):
        self.mnist_data = mnist_data
        self.num_samples = num_samples

        # Bucket every image under its label for fast per-digit sampling.
        self.label_to_images = defaultdict(list)
        for picture, digit in mnist_data:
            self.label_to_images[digit].append(picture)

        # Fixed per-sample choices: the hidden digit, then the ordering flag
        # (0 -> lower neighbour first, 1 -> higher neighbour first).
        self.random_nums = [random.randint(1, 8) for _ in range(num_samples)]
        self.random_order = [random.randint(0, 1) for _ in range(num_samples)]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        hidden = self.random_nums[idx]
        below = random.sample(self.label_to_images[hidden - 1], 8)
        above = random.sample(self.label_to_images[hidden + 1], 8)
        panels = below + above if self.random_order[idx] == 0 else above + below
        return torch.stack(panels), hidden - 1
class RPMSentencesRaw(Dataset):
    """Dataset yielding the raw panels of each puzzle file as a float tensor.

    Each path in ``files`` is opened with ``np.load`` and must contain an
    ``image`` array that can be reshaped to (16, 160, 160), plus a scalar
    ``target``. The ``with`` usage below requires ``.npz`` archives, because
    only those are returned as a context-manager ``NpzFile``.

    Items are ``(imagetensor, target)``:

    * ``imagetensor``: float tensor of shape (16, 1, 160, 160), pixel values
      divided by 255;
    * ``target``: the Python scalar stored under ``target``.
    """

    def __init__(self, files):
        self.files = files

    def __getitem__(self, idx):
        filename = self.files[idx]
        # Read everything we need inside the context manager so the archive's
        # file handle is closed immediately instead of whenever it is GC'd.
        with np.load(filename) as data:
            image = data['image'].reshape(16, 160, 160)
            target = data['target'].item()
        # Scale pixels by 1/255 and add a channel axis -> (16, 1, 160, 160).
        imagetensor = torch.from_numpy(image).float() / 255
        imagetensor = imagetensor.unsqueeze(1)
        return imagetensor, target

    def __len__(self):
        return len(self.files)
class RPMSentencesNew(Dataset):
    """Dataset yielding autoencoder embeddings of all panels plus a one-hot target.

    Each path in ``files`` is opened with ``np.load`` and must contain an
    ``image`` array that can be reshaped to (16, 160, 160), plus an integer
    ``target`` in 0..7. The ``with`` usage below requires ``.npz`` archives.

    Args:
        files: sequence of archive paths.
        ResNetAutoencoder: object exposing ``get_embedding(tensor)``. It is
            called with a (16, 1, 160, 160) float tensor on ``device``.
        device: device the panel tensor is moved to before embedding.

    Items are ``(embeddings, target)``:

    * ``embeddings``: whatever ``get_embedding`` returns;
    * ``target``: float64 NumPy array of length 8, with a 1 at the stored
      target index (presumably the correct answer candidate; TODO confirm).
    """

    def __init__(self, files, ResNetAutoencoder, device):
        self.files = files
        self.autoencoder = ResNetAutoencoder
        self.device = device

    def __getitem__(self, idx):
        filename = self.files[idx]
        # Read inside the context manager so the archive is closed promptly.
        with np.load(filename) as data:
            image = data['image'].reshape(16, 160, 160)
            answer = data['target'].item()
        # Scale pixels by 1/255 and add a channel axis -> (16, 1, 160, 160).
        imagetensor = torch.from_numpy(image).float() / 255
        imagetensor = imagetensor.unsqueeze(1).to(self.device)
        # NOTE(review): if the autoencoder is wrapped in DataParallel, pass the
        # underlying ``.module`` here, since that is what exposes get_embedding.
        embeddings = self.autoencoder.get_embedding(imagetensor)
        target = np.zeros((8,))
        target[answer] = 1
        return embeddings, target

    def __len__(self):
        return len(self.files)
# 1. Dataset
class RPMSentences(Dataset):
    """Masked-panel training dataset over the first 8 panels of each puzzle.

    Every file contributes 8 items, one per masked panel, so
    ``len(self) == 8 * len(files)``. For item ``idx``:

    * ``idx // 8`` selects the file and ``idx % 8`` selects the masked panel;
    * the 8 panel embeddings plus a zero padding row form a 3x3 grid, which
      is rotated by ``idx % 4`` quarter turns.

    Args:
        files: sequence of ``.npz`` archive paths, each with an ``image`` array.
        ResNetAutoencoder: object exposing ``get_embedding(tensor)``. It must
            return one ``embed_dim``-wide row per input panel.
        embed_dim: embedding width.
        device: device for the panel tensor, mask token and padding token.

    Items are ``(final_sentence, target, final_mask_tensor)``:

    * ``final_sentence`` (9, embed_dim): the rotated grid, with the masked
      panel replaced by an all-ones token;
    * ``target``: the unmasked embedding of the masked panel;
    * ``final_mask_tensor`` (9, 2 * embed_dim): ones on the row that holds the
      mask token after rotation, zeros elsewhere. It is created on the CPU
      regardless of ``device``.
    """

    def __init__(self, files, ResNetAutoencoder, embed_dim, device):
        self.files = files
        self.autoencoder = ResNetAutoencoder
        self.embed_dim = embed_dim
        self.device = device

    def __getitem__(self, idx):
        mask = torch.ones(self.embed_dim).to(self.device)  # masking token
        pad = torch.zeros([1, self.embed_dim]).to(self.device)  # fills the 9th grid slot

        # Decode idx consistently with __len__ (8 items per file). The former
        # ``idx // (8 * 4)`` only ever reached the first quarter of the files.
        fileidx = idx // 8
        panelidx = idx % 8
        # NOTE(review): k == idx % 4 == panelidx % 4, so a given panel always
        # gets the same rotation; confirm this coupling is intended.
        k = idx % 4

        # Read inside the context manager so the archive is closed promptly.
        with np.load(self.files[fileidx]) as data:
            image = data['image'][0:8, :, :]
        # Only the first 8 panels are embedded; scale by 1/255 and add a channel axis.
        imagetensor = torch.from_numpy(image).float() / 255
        imagetensor = imagetensor.unsqueeze(1).to(self.device)
        embeddings = self.autoencoder.get_embedding(imagetensor)

        maskedsentence = embeddings.clone()
        maskedsentence[panelidx, :] = mask
        paddedmaskedsentence = torch.cat([maskedsentence, pad], 0)  # (9, embed_dim)

        # Rotate the sentence and the mask indicator by the same k, so the
        # indicator row still points at the mask token.
        grid = paddedmaskedsentence.reshape([3, 3, self.embed_dim])
        final_sentence = torch.rot90(grid, k=k, dims=[0, 1]).reshape([9, self.embed_dim])

        mask_tensor = torch.zeros(9, self.embed_dim * 2)
        mask_tensor[panelidx, :] = 1
        maskgrid = mask_tensor.reshape([3, 3, self.embed_dim * 2])
        final_mask_tensor = torch.rot90(maskgrid, k=k, dims=[0, 1]).reshape([9, self.embed_dim * 2])

        target = embeddings[panelidx, :]
        return final_sentence, target, final_mask_tensor

    def __len__(self):
        return len(self.files) * 8
# Dataset for evaluation
class RPMFullSentences(Dataset):
    """Evaluation dataset: first 8 panel embeddings plus a mask token in slot 9.

    Each path in ``files`` is an ``.npz`` archive with an ``image`` array and
    an integer ``target``. All panels in ``image`` are embedded. The panel at
    row ``target + 8`` is returned as the target embedding (presumably the
    correct answer among candidates stored after the 8 context panels; TODO
    confirm).

    Args:
        files: sequence of archive paths.
        ResNetAutoencoder: object exposing ``get_embedding(tensor)``. It must
            return one ``embed_dim``-wide row per input panel.
        embed_dim: embedding width.
        device: device for the panel tensor and the mask token.

    Items are ``(maskedsentence, target_embed, imagetensor, target_num,
    embeddings, mask_tensor)``:

    * ``maskedsentence`` (9, embed_dim): embeddings of panels 0..7 followed by
      an all-ones mask row;
    * ``target_embed``: embedding of panel ``target_num + 8``;
    * ``imagetensor``: all panels as floats divided by 255, with a channel
      axis added;
    * ``target_num``: the stored target index;
    * ``embeddings``: embeddings of all panels;
    * ``mask_tensor`` (9, 2 * embed_dim): ones on row 8, zeros elsewhere,
      created on the CPU.
    """

    def __init__(self, files, ResNetAutoencoder, embed_dim, device):
        self.files = files
        self.autoencoder = ResNetAutoencoder
        self.embed_dim = embed_dim
        self.device = device

    def __getitem__(self, idx):
        mask = torch.ones([1, self.embed_dim]).to(self.device)  # masking token for slot 9
        filename = self.files[idx]
        # Read inside the context manager so the archive is closed promptly.
        with np.load(filename) as data:
            image = data['image']
            target_num = data['target'].item()
        imagetensor = torch.from_numpy(image).float() / 255
        imagetensor = imagetensor.unsqueeze(1).to(self.device)
        embeddings = self.autoencoder.get_embedding(imagetensor)
        sentence = embeddings[0:8, :]
        maskedsentence = torch.cat([sentence, mask], 0)  # (9, embed_dim)
        target_embed = embeddings[target_num + 8, :]
        mask_tensor = torch.zeros(9, self.embed_dim * 2)
        mask_tensor[8, :] = 1  # the mask token always sits in the last slot
        return maskedsentence, target_embed, imagetensor, target_num, embeddings, mask_tensor

    def __len__(self):
        return len(self.files)