-
Notifications
You must be signed in to change notification settings - Fork 19
/
data.py
118 lines (108 loc) · 3.81 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from re import X
import scipy.io
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from sklearn.preprocessing import OrdinalEncoder
class dataset(Dataset):
    """Map-style dataset yielding fixed-length crops of trial arrays.

    Item ``idx`` is ``(X[idx][:, :length], y[idx])``: the first ``length``
    entries along the second axis of the stored trial, paired with its label.

    Args:
        X: indexable collection of trials; each trial must support
            ``[:, start:stop]`` slicing (e.g. a 2-D numpy array).
        y: labels; ``len(y)`` defines the dataset length.
        train: stored as ``self.train``. Cropping is currently identical for
            training and evaluation (a random-offset training crop existed
            but was disabled).
        length: number of samples kept from the start of each trial
            (default 4000).
    """

    def __init__(self, X, y, train=True, length=4000):
        self.X = X
        self.y = y
        self.train = train
        self.length = length

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Deterministic crop from the start of the trial for every split.
        return self.X[idx][:, :self.length], self.y[idx]
def generate_x_train(mat, window=8000):
    """Cut fixed-length trial windows out of a continuous recording.

    One window of ``window`` samples is taken starting at each marker
    position in ``mat['mrk'][0][0][0][0]``. Windows may overlap the next
    trial, which is accepted here. The final marker contributes a window
    only when at least ``window`` samples remain after it.

    Args:
        mat: mapping as returned by ``scipy.io.loadmat``; reads ``mat['cnt']``
            (sliced along its first axis, then transposed) and the marker
            positions at ``mat['mrk'][0][0][0][0]``.
        window: samples per trial (default 8000).

    Returns:
        ``(data, last_label)``: ``data`` stacks the transposed windows, i.e.
        shape ``(num_trials, cnt.shape[1], window)``; ``last_label`` is True
        when the final marker produced a window.

    Raises:
        ValueError: if the window of a non-final marker runs past the end of
            the recording (the trials could not be stacked).
    """
    markers = mat['mrk'][0][0][0][0]
    cnt = mat['cnt']
    n_samples = len(cnt)

    if len(markers) == 0:
        return np.empty((0, cnt.shape[1], window), dtype=cnt.dtype), False

    data = []
    # Every marker except the last always yields a window. Compute the end
    # directly as start + window instead of shifting the next marker.
    for start in markers[:-1]:
        end = start + window
        if end > n_samples:
            raise ValueError(
                f"window starting at sample {start} needs {window} samples "
                f"but the recording has only {n_samples}"
            )
        data.append(cnt[start:end].T)

    # The last marker has no successor; keep it only if a full window fits.
    last_start = markers[-1]
    last_label = bool(n_samples - last_start >= window)
    if last_label:
        data.append(cnt[last_start:last_start + window].T)
    return np.array(data), last_label
def generate_y_train(mat, last_label):
    """Map per-marker numeric labels (-1 / 1) to class-name strings.

    Args:
        mat: mapping as returned by ``scipy.io.loadmat``; the two class names
            are read from ``mat['nfo']['classes']`` and the per-marker labels
            from ``mat['mrk'][0][0][1]`` (first row).
        last_label: flag returned by ``generate_x_train``. When False the
            final marker produced no window, so its label is dropped to keep
            X and y aligned.

    Returns:
        1-D numpy array of class names; -1 maps to the first class, 1 to the
        second.

    Raises:
        ValueError: if any label is neither -1 nor 1 (e.g. NaN), rather than
            silently mapping it to None.
    """
    classes = mat['nfo']['classes'][0][0][0]
    mapping = {-1: classes[0][0], 1: classes[1][0]}
    raw = np.asarray(mat['mrk'][0][0][1])[0].tolist()
    # NaN is never found in the dict, so it is reported as unknown too.
    unknown = {value for value in raw if value not in mapping}
    if unknown:
        raise ValueError(
            f"unexpected label values {sorted(map(str, unknown))}; expected -1 or 1"
        )
    labels = np.array([mapping[value] for value in raw])
    if not last_label:
        labels = labels[:-1]
    return labels
def generate_data(files):
    """Load each ``.mat`` recording and stack all trials and labels.

    The path of every file is printed as it is processed.

    Args:
        files: iterable of paths accepted by ``scipy.io.loadmat``.

    Returns:
        ``(X, y)``: ``X`` is the per-file trial arrays concatenated along the
        first (trial) axis; ``y`` is the concatenated class names converted
        by ``OrdinalEncoder.fit_transform`` into a single-column 2-D array.
    """
    trial_batches = []
    label_batches = []
    for path in files:
        print(path)
        recording = scipy.io.loadmat(path)
        trials, has_last = generate_x_train(recording)
        trial_batches.append(trials)
        label_batches.append(generate_y_train(recording, has_last))
    stacked_trials = np.concatenate(trial_batches, axis=0)
    stacked_labels = np.concatenate(label_batches)
    encoder = OrdinalEncoder()
    encoded = encoder.fit_transform(stacked_labels.reshape(-1, 1))
    return stacked_trials, encoded
def split_data(X, y, train_size=1000, val_size=197, seed=42):
    """Shuffle trials and split them into train / validation / test sets.

    A seeded permutation of ``range(len(y))`` is cut into the first
    ``train_size`` indices (train), the next ``val_size`` (validation), and
    everything that remains (test). With the defaults and 1397 trials this
    gives 1000 / 197 / 200.

    Args:
        X: array indexable by an integer index array (trials on axis 0).
        y: labels aligned with ``X``; ``len(y)`` is the number of trials.
        train_size: number of training trials (default 1000).
        val_size: number of validation trials (default 197).
        seed: seed for the shuffling permutation (default 42).

    Returns:
        ``(train_X, train_y, val_X, val_y, test_X, test_y)``.

    Raises:
        ValueError: if ``train_size + val_size`` exceeds the number of trials.
    """
    n = len(y)
    if train_size + val_size > n:
        raise ValueError(
            f"train_size + val_size = {train_size + val_size} exceeds the "
            f"{n} available trials"
        )
    # A private RandomState seeded with `seed` produces the same permutation
    # as np.random.seed(seed) followed by np.random.choice(n, n,
    # replace=False), but leaves numpy's global RNG state untouched.
    indices = np.random.RandomState(seed).permutation(n)
    train_idx = indices[:train_size]
    val_idx = indices[train_size:train_size + val_size]
    test_idx = indices[train_size + val_size:]
    return (
        X[train_idx], y[train_idx],
        X[val_idx], y[val_idx],
        X[test_idx], y[test_idx],
    )
def get_loaders(train_X, train_y, val_X, val_y, test_X, test_y, *,
                batch_size=1, num_workers=1, pin_memory=True):
    """Wrap each split in a ``dataset`` and build one DataLoader per split.

    The training split is constructed with ``train=True``; validation and
    test with ``train=False``. All loaders share the same settings, no
    shuffling is requested, and ``drop_last`` is False.

    Args:
        train_X, train_y, val_X, val_y, test_X, test_y: split arrays as
            returned by ``split_data``.
        batch_size: samples per batch for every loader (default 1).
        num_workers: worker processes per loader (default 1).
        pin_memory: passed to every DataLoader (default True).

    Returns:
        dict mapping ``'train'``, ``'val'`` and ``'test'`` to their
        ``torch.utils.data.DataLoader``.
    """
    splits = {
        'train': dataset(train_X, train_y, True),
        'val': dataset(val_X, val_y, False),
        'test': dataset(test_X, test_y, False),
    }
    return {
        name: torch.utils.data.DataLoader(
            split,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False,
        )
        for name, split in splits.items()
    }