-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathplot_c3c4cz.py
81 lines (58 loc) · 2.17 KB
/
plot_c3c4cz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
from matplotlib import pyplot as plt
class MotorImageryDataset:
    """Load a motor-imagery recording stored as ``.npz`` and cut it into trials.

    The archive must contain the keys ``'s'``, ``'etyp'``, ``'epos'``,
    ``'edur'`` and ``'artifacts'``; every array is transposed on load.
    After transposing, ``raw`` is indexed as ``raw[channel, sample]`` and the
    event arrays are indexed as ``events_*[0, event_index]``.
    """

    def __init__(self, dataset='A01T.npz'):
        """Load *dataset*, appending the ``.npz`` extension if it is missing.

        Args:
            dataset: Path to the ``.npz`` archive.
        """
        if not dataset.endswith('.npz'):
            dataset += '.npz'
        # Read every array eagerly and close the archive: a bare np.load on an
        # .npz returns a lazy NpzFile that keeps the file open indefinitely.
        with np.load(dataset) as npz:
            self.data = {key: npz[key] for key in npz.files}
        self.Fs = 250  # 250Hz from original paper
        # keys of data ['s', 'etyp', 'epos', 'edur', 'artifacts']
        self.raw = self.data['s'].T
        self.events_type = self.data['etyp'].T
        self.events_position = self.data['epos'].T
        self.events_duration = self.data['edur'].T
        self.artifacts = self.data['artifacts'].T
        # Types of motor imagery: event code -> class label.
        self.mi_types = {769: 'left', 770: 'right',
                         771: 'foot', 772: 'tongue', 783: 'unknown'}

    def get_trials_from_channel(self, channel=7):
        """Return the trials recorded on one channel and their class labels.

        Every event with code 768 starts a trial. The trial spans
        ``events_duration`` samples from its ``events_position``. Its label is
        the ``mi_types`` entry for the event immediately following it. A trial
        is skipped when it is the last event or when the following code is not
        in ``mi_types``.

        Args:
            channel: Row index into ``self.raw``; the default (7) is C3.

        Returns:
            ``(trials, classes)``: a list of ``(1, n_samples)`` arrays and a
            parallel list of class-name strings of the same length.

        Raises:
            IndexError: if *channel* is not a valid row of ``self.raw``.
        """
        startrial_code = 768
        starttrial_events = self.events_type == startrial_code
        idxs = np.flatnonzero(starttrial_events[0])
        trials = []
        classes = []
        for index in idxs:
            try:
                type_e = self.events_type[0, index + 1]
                class_e = self.mi_types[type_e]
            except (IndexError, KeyError):
                # Start marker is the final event, or the cue code is unknown.
                continue
            start = self.events_position[0, index]
            stop = start + self.events_duration[0, index]
            trial = self.raw[channel, start:stop].reshape((1, -1))
            # Append both only once everything succeeded so the lists stay aligned.
            trials.append(trial)
            classes.append(class_e)
        return trials, classes

    def get_trials_from_channels(self, channels=(7, 9, 11)):
        """Collect trials for several channels, stacked per channel.

        Args:
            channels: Iterable of row indices into ``self.raw``.

        Returns:
            ``(trials_c, classes_c)``: for each channel in order, a
            ``(n_trials, n_samples)`` array and its list of class names.

        Raises:
            ValueError: if a channel yields no trials, or if its trials differ
                in length (``np.concatenate`` cannot stack them).
        """
        trials_c = []
        classes_c = []
        for channel in channels:
            trials, classes = self.get_trials_from_channel(channel=channel)
            if not trials:
                raise ValueError(f'no trials found for channel {channel}')
            # Stacking row-wise requires every trial on a channel to be equally long.
            trials_c.append(np.concatenate(trials, axis=0))
            classes_c.append(classes)
        return trials_c, classes_c
datasetA1 = MotorImageryDataset('A01T.npz')
trials, classes = datasetA1.get_trials_from_channels([7, 9, 11])
# Draw one stacked trials-by-samples image per electrode. The labels follow
# the channel order requested above: 7 -> C3, 9 -> Cz, 11 -> C4.
for row, (image, label) in enumerate(zip(trials, ('C3', 'Cz', 'C4')), start=1):
    plt.subplot(3, 1, row)
    plt.imshow(image)
    plt.title(label, size=22)
plt.show()