forked from MELABIPCAS/DVSCL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_multi.py
138 lines (89 loc) · 4.56 KB
/
dataset_multi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
from torch.utils import data
class RandomSampler(data.sampler.Sampler):
    """Sampler driven by an externally supplied index order.

    Without replacement, iteration simply walks the caller-provided
    ``inputrandom`` sequence. With replacement, ``num_samples`` indices are
    drawn uniformly at random from the dataset on every iteration.
    """

    def __init__(self, data_source, inputrandom, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.s = inputrandom  # externally fixed iteration order
        # Validate arguments up front so misuse fails at construction time.
        if not isinstance(self.replacement, bool):
            raise ValueError(
                "replacement should be a boolean value, but got replacement={}".format(self.replacement))
        if not replacement and self._num_samples is not None:
            raise ValueError(
                "With replacement=False, num_samples should not be specified, since a random permute will be performed.")
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "num_samples should be a positive integer value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # Fall back to the live dataset length; the dataset may grow or
        # shrink at runtime, so this is evaluated lazily.
        return len(self.data_source) if self._num_samples is None else self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if not self.replacement:
            # Deterministic pass over the caller-supplied order.
            return iter(self.s)
        picks = torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64)
        return iter(picks.tolist())

    def __len__(self):
        return self.num_samples
# Only needed when stitching two datasets together into one
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets together.

    Item ``i`` is the tuple of each wrapped dataset's item ``i``; the
    length is the shortest wrapped dataset, so every index is valid in
    all of them.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, index):
        return tuple(ds[index] for ds in self.datasets)

    def __len__(self):
        return min(map(len, self.datasets))
"""import numpy as np
from torch.utils.data.dataset import Dataset
import pickle
import os
from scipy import signal
import torch
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
torch.set_default_tensor_type('torch.FloatTensor')
############################################################################################
# This file provides basic processing script for the multimodal datasets we use. For other
# datasets, small modifications may be needed (depending on the type of the data, etc.)
############################################################################################
class Multimodal_Datasets(Dataset):
def __init__(self, dataset_path, data='mosei_senti', split_type='train', if_align=False):
super(Multimodal_Datasets, self).__init__()
dataset_path = os.path.join(dataset_path, data + '_data.pkl' if if_align else data + '_data_noalign.pkl')
dataset = pickle.load(open(dataset_path, 'rb'))
# These are torch tensors
self.vision = torch.tensor(dataset[split_type]['vision'].astype(np.float32)).cpu().detach()
self.text = torch.tensor(dataset[split_type]['text'].astype(np.float32)).cpu().detach()
self.audio = dataset[split_type]['audio'].astype(np.float32)
self.audio[self.audio == -np.inf] = 0
self.audio = torch.tensor(self.audio).cpu().detach()
self.labels = torch.tensor(dataset[split_type]['labels'].astype(np.float32)).cpu().detach()
# Note: this is STILL an numpy array
self.meta = dataset[split_type]['id'] if 'id' in dataset[split_type].keys() else None
self.data = data
self.n_modalities = 3 # vision/ text/ audio
def get_n_modalities(self):
return self.n_modalities
def get_seq_len(self):
return self.text.shape[1], self.audio.shape[1], self.vision.shape[1]
def get_dim(self):
return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]
def get_lbl_info(self):
# return number_of_labels, label_dim
return self.labels.shape[1], self.labels.shape[2]
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
X = (index, self.text[index], self.audio[index], self.vision[index])
Y = self.labels[index]
META = (0, 0, 0) if self.meta is None else (self.meta[index][0], self.meta[index][1], self.meta[index][2])
if self.data == 'mosi':
META = (self.meta[index][0].decode('UTF-8'), self.meta[index][1].decode('UTF-8'),
self.meta[index][2].decode('UTF-8'))
if self.data == 'iemocap':
Y = torch.argmax(Y, dim=-1)
return X, Y, META"""