# dataset.py
import glob
import logging
import os.path
import re

import mne
import numpy as np

log = logging.getLogger(__name__)


def session_key(file_name):
    """ provides a sorting key based on the session id (e.g. 's01') in the file name """
    return re.findall(r'(s\d{2})', file_name)


def natural_key(file_name):
    """ provides a human-like sorting key of a string """
    # digit runs compare numerically; non-digit runs are mapped to None, which
    # compares as equal across file names that share the same naming pattern
    key = [int(token) if token.isdigit() else None
           for token in re.split(r'(\d+)', file_name)]
    return key
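
# Example (hypothetical file names, for illustration): digit runs compare
# numerically, so 'rec2.edf' sorts before 'rec12.edf', whereas plain string
# sorting would put 'rec12.edf' first:
#   natural_key('rec12.edf')  -> [None, 12, None]
#   sorted(['rec12.edf', 'rec2.edf'], key=natural_key)
#   -> ['rec2.edf', 'rec12.edf']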


def time_key(file_name):
    """ provides a time-based sorting key """
    # assumes '/'-separated paths whose parent directory is the session folder,
    # e.g. '.../s01_2012_02_23/<recording>.edf'
    splits = file_name.split('/')
    [date] = re.findall(r'(\d{4}_\d{2}_\d{2})', splits[-2])
    date_id = [int(token) for token in date.split('_')]
    recording_id = natural_key(splits[-1])
    session_id = session_key(splits[-2])
    return date_id + session_id + recording_id
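
# Example (hypothetical TUH-style path, for illustration):
#   time_key('/data/train/abnormal/00000021/s01_2012_02_23/00000021_s01_a00.edf')
#   -> [2012, 2, 23, 's01', None, 21, None, 1, None, 0, None]
# i.e. recordings sort first by date, then by session id, then by recording name.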


def read_all_file_names(path, extension, key="time"):
    """ read all files with the specified extension from the given path
    :param path: parent directory holding the files directly or in subdirectories;
        must end with a path separator so the glob pattern below can recurse
    :param extension: the type of the file, e.g. '.txt' or '.edf'
    :param key: the sorting of the files. 'natural' sorts e.g. 1, 2, 12, 21
        (where machine sorting gives 1, 12, 2, 21); 'time' sorts by the recording
        time encoded in the edf file names, which is important for cv
    """
    file_paths = glob.glob(path + '**/*' + extension, recursive=True)
    if key == 'time':
        return sorted(file_paths, key=time_key)
    elif key == 'natural':
        return sorted(file_paths, key=natural_key)
    else:
        raise ValueError("key has to be either 'time' or 'natural', got {}"
                         .format(key))
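
# Example (hypothetical directory, for illustration):
#   file_paths = read_all_file_names('/data/tuh_abnormal/train/', '.edf', key='time')
# Note the trailing '/': the glob pattern is built by string concatenation, so
# without it '**' is glued onto the directory name and does not recurse.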


def get_info_with_mne(file_path):
    """ read info from the edf file without loading the data. loading data is done in multiprocessing since it takes
    some time. getting info is done before because some files had corrupted headers or weird sampling frequencies
    that caused the multiprocessing workers to crash. therefore get and check e.g. sampling frequency and duration
    beforehand
    :param file_path: path of the recording file
    :return: raw edf object, sampling frequency, number of samples, number of signals, signal names, duration of the rec
    """
    try:
        edf_file = mne.io.read_raw_edf(file_path, verbose='error')
    except ValueError:
        return None, None, None, None, None, None
        # fix_header(file_path)
        # try:
        #     edf_file = mne.io.read_raw_edf(file_path, verbose='error')
        #     logging.warning("Fixed it!")
        # except ValueError:
        #     return None, None, None, None, None, None

    # some recordings have a very weird sampling frequency. check twice before skipping the file
    sampling_frequency = int(edf_file.info['sfreq'])
    if sampling_frequency < 10:
        sampling_frequency = 1 / (edf_file.times[1] - edf_file.times[0])
        if sampling_frequency < 10:
            return None, sampling_frequency, None, None, None, None

    n_samples = edf_file.n_times
    signal_names = edf_file.ch_names
    n_signals = len(signal_names)
    # some weird sampling frequencies are at 1 hz or below, which results in division by zero
    duration = n_samples / max(sampling_frequency, 1)

    # TODO: return rec object?
    return edf_file, sampling_frequency, n_samples, n_signals, signal_names, duration
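
# Example (hypothetical file, for illustration): callers are expected to check
# the first return value before using the recording:
#   cnt, sfreq, n_samples, n_signals, ch_names, duration = get_info_with_mne(
#       '/data/tuh_abnormal/train/abnormal/.../recording.edf')
#   if cnt is None:
#       pass  # unreadable header or implausible sampling frequency, skip file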


def get_recording_length(file_path):
    """ some recordings were so huge that simply opening them with mne caused the program to crash. therefore, open
    the edf as bytes and only read the header. parse the duration from there and check if the file can safely be opened
    :param file_path: path of the recording file
    :return: the duration of the recording
    """
    with open(file_path, 'rb') as f:
        header = f.read(256)
    return int(header[236:244].decode('ascii'))
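
# Note: bytes 236-244 of an EDF header hold the number of data records as an
# ASCII string. This equals the duration in seconds only when each data record
# spans one second (the record duration sits in bytes 244-252), which the code
# above implicitly assumes.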


def load_data(fname, preproc_functions, sensor_types=['EEG']):
    """ load the eeg (and optionally ekg) channels of a recording, convert to
    microvolt and apply the given preprocessing functions """
    cnt, sfreq, n_samples, n_channels, chan_names, n_sec = get_info_with_mne(
        fname)
    log.info("Load data...")
    cnt.load_data()
    selected_ch_names = []
    if 'EEG' in sensor_types:
        wanted_elecs = ['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8',
                        'FP1', 'FP2', 'FZ', 'O1', 'O2', 'P3', 'P4', 'PZ',
                        'T3', 'T4', 'T5', 'T6']
        for wanted_part in wanted_elecs:
            wanted_found_name = []
            for ch_name in cnt.ch_names:
                if ' ' + wanted_part + '-' in ch_name:
                    wanted_found_name.append(ch_name)
            assert len(wanted_found_name) == 1
            selected_ch_names.append(wanted_found_name[0])
    if 'EKG' in sensor_types:
        wanted_found_name = []
        for ch_name in cnt.ch_names:
            if 'EKG' in ch_name:
                wanted_found_name.append(ch_name)
        assert len(wanted_found_name) == 1
        selected_ch_names.append(wanted_found_name[0])
    cnt = cnt.pick_channels(selected_ch_names)
    assert np.array_equal(cnt.ch_names, selected_ch_names)
    n_sensors = 0
    if 'EEG' in sensor_types:
        n_sensors += 21
    if 'EKG' in sensor_types:
        n_sensors += 1
    assert len(cnt.ch_names) == n_sensors, (
        "Expected {:d} channel names, got {:d} channel names".format(
            n_sensors, len(cnt.ch_names)))

    # change from volt to microvolt
    data = (cnt.get_data() * 1e6).astype(np.float32)
    fs = cnt.info['sfreq']
    log.info("Preprocessing...")
    for fn in preproc_functions:
        log.info(fn)
        data, fs = fn(data, fs)
        data = data.astype(np.float32)
        fs = float(fs)
    return data
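
# Example preprocessing function (hypothetical, for illustration): every entry
# of `preproc_functions` is expected to map (data, fs) to (data, fs), e.g.
# clipping each recording to its first minute:
#   def take_first_minute(data, fs):
#       return data[:, :int(60 * fs)], fs
#   data = load_data('/data/some_recording.edf',
#                    preproc_functions=[take_first_minute])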


def get_all_sorted_file_names_and_labels(train_or_eval, folders):
    """ collect all recording file names from the given folders in time order
    and derive a binary label from the path (1 if it contains '/abnormal/') """
    all_file_names = []
    for folder in folders:
        full_folder = os.path.join(folder, train_or_eval) + '/'
        log.info("Reading {:s}...".format(full_folder))
        this_file_names = read_all_file_names(full_folder, '.edf', key='time')
        log.info(".. {:d} files.".format(len(this_file_names)))
        all_file_names.extend(this_file_names)
    log.info("{:d} files in total.".format(len(all_file_names)))
    all_file_names = sorted(all_file_names, key=time_key)
    labels = ['/abnormal/' in f for f in all_file_names]
    labels = np.array(labels).astype(np.int64)
    return all_file_names, labels
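
# Example (hypothetical folder layout, for illustration): with recordings under
# '/data/tuh_abnormal/train/abnormal/' and '/data/tuh_abnormal/train/normal/',
#   file_names, labels = get_all_sorted_file_names_and_labels(
#       train_or_eval='train', folders=['/data/tuh_abnormal/'])
# yields labels of 1 for the abnormal and 0 for the normal recordings.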


class DiagnosisSet(object):
    def __init__(self, n_recordings, max_recording_mins, preproc_functions,
                 data_folders,
                 train_or_eval='train', sensor_types=['EEG']):
        self.n_recordings = n_recordings
        self.max_recording_mins = max_recording_mins
        self.preproc_functions = preproc_functions
        self.train_or_eval = train_or_eval
        self.sensor_types = sensor_types
        self.data_folders = data_folders

    def load(self, only_return_labels=False):
        log.info("Read file names")
        all_file_names, labels = get_all_sorted_file_names_and_labels(
            train_or_eval=self.train_or_eval,
            folders=self.data_folders)
        if self.max_recording_mins is not None:
            log.info("Read recording lengths...")
            assert 'train' == self.train_or_eval
            # only keep recordings shorter than the maximum duration (in seconds)
            lengths = [get_recording_length(fname) for fname in all_file_names]
            lengths = np.array(lengths)
            mask = lengths < self.max_recording_mins * 60
            cleaned_file_names = np.array(all_file_names)[mask]
            cleaned_labels = labels[mask]
        else:
            cleaned_file_names = np.array(all_file_names)
            cleaned_labels = labels
        if only_return_labels:
            return cleaned_labels
        X = []
        y = []
        n_files = len(cleaned_file_names[:self.n_recordings])
        for i_fname, fname in enumerate(cleaned_file_names[:self.n_recordings]):
            log.info("Load {:d} of {:d}".format(i_fname + 1, n_files))
            x = load_data(fname, preproc_functions=self.preproc_functions,
                          sensor_types=self.sensor_types)
            assert x is not None
            X.append(x)
            y.append(cleaned_labels[i_fname])
        y = np.array(y)
        return X, y
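
# Example usage (hypothetical paths and preprocessors, for illustration):
#   dataset = DiagnosisSet(n_recordings=150, max_recording_mins=35,
#                          preproc_functions=[take_first_minute],
#                          data_folders=['/data/tuh_abnormal/'],
#                          train_or_eval='train', sensor_types=['EEG'])
#   X, y = dataset.load()
# X is a list of (channels x samples) float32 arrays, y the binary labels.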