-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdetect_alpha_peak.py
executable file
·183 lines (121 loc) · 6.03 KB
/
detect_alpha_peak.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#!/usr/local/bin/python3
import json
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
def get_alpha_freqs(freqs, fmin=6.9, fmax=14.1):
    """Select the frequencies that fall inside the alpha band.

    The bounds are exclusive. The defaults select the 7 - 14 Hz band with a
    0.1 Hz margin on each side.

    Parameters
    ----------
    freqs : array-like of float
        Frequencies (Hz) of the PSD bins.
    fmin : float, optional
        Exclusive lower bound of the band, in Hz.
    fmax : float, optional
        Exclusive upper bound of the band, in Hz.

    Returns
    -------
    alpha_freqs : numpy.ndarray
        The values of ``freqs`` lying strictly between ``fmin`` and ``fmax``.
    indexes_alpha_freqs : list of int
        Positions of those values in ``freqs``. Both outputs are empty when
        no frequency falls inside the band.
    """
    freqs = np.asarray(freqs)
    # Boolean mask instead of a Python-level index loop; it also keeps the
    # empty-band case well defined (no empty index list handed to np.take).
    in_band = (freqs > fmin) & (freqs < fmax)
    indexes_alpha_freqs = np.flatnonzero(in_band).tolist()
    alpha_freqs = freqs[in_band]
    return alpha_freqs, indexes_alpha_freqs
def detect_alpha_peak_mean(psd_welch, alpha_freqs, indexes_alpha_freqs):
    """Detect the alpha peak frequency on the PSD averaged across channels.

    Parameters
    ----------
    psd_welch : array-like
        PSD values. Statistics are taken along axis 0, i.e. rows are
        expected to be channels and columns frequency bins.
    alpha_freqs : array-like of float
        Frequencies (Hz) of the alpha band.
    indexes_alpha_freqs : sequence of int
        Positions of ``alpha_freqs`` along the frequency axis of
        ``psd_welch``.

    Returns
    -------
    alpha_freq_pic_mean : float
        Frequency of the detected peak of the averaged PSD.
    psd_welch_mean : numpy.ndarray
        PSD averaged over axis 0.
    psd_welch_std : numpy.ndarray
        Standard deviation of the PSD over axis 0.
    psd_in_alpha_freqs_mean : numpy.ndarray
        The averaged PSD restricted to the alpha band.

    Raises
    ------
    ValueError
        If the alpha band is empty or no peak is found inside it.
    """
    if len(indexes_alpha_freqs) == 0:
        raise ValueError("No frequency bin falls inside the alpha band.")

    # Average PSD across all channels
    psd_welch_mean = np.mean(psd_welch, axis=0)
    # Spread across channels (standard deviation around the mean PSD)
    psd_welch_std = np.std(psd_welch, axis=0)

    # Extract psd in alpha freqs
    psd_in_alpha_freqs_mean = np.take(psd_welch_mean, indexes_alpha_freqs)

    # peak_finder returns (locations, magnitudes); only locations are needed.
    peak_locations = np.atleast_1d(
        mne.preprocessing.peak_finder(psd_in_alpha_freqs_mean)[0]
    )
    if peak_locations.size == 0:
        raise ValueError(
            "No peak found in the alpha band of the channel-averaged PSD."
        )

    # Several peaks may be reported: keep the one with the highest index,
    # the same rule detect_alpha_peak_per_channels applies, instead of
    # failing on the conversion of a multi-element array to int.
    index_of_the_pic = int(peak_locations.max())
    # The index refers to alpha_freqs, not to the full frequency axis.
    alpha_freq_pic_mean = alpha_freqs[index_of_the_pic]

    return alpha_freq_pic_mean, psd_welch_mean, psd_welch_std, psd_in_alpha_freqs_mean
def detect_alpha_peak_per_channels(psd_welch, alpha_freqs, indexes_alpha_freqs):
    """Detect the alpha peak frequency of every channel.

    Parameters
    ----------
    psd_welch : array-like
        PSD values, one row per channel and one column per frequency bin.
    alpha_freqs : array-like of float
        Frequencies (Hz) of the alpha band.
    indexes_alpha_freqs : sequence of int
        Positions of ``alpha_freqs`` along the frequency axis of
        ``psd_welch``.

    Returns
    -------
    alpha_freq_pic_per_channel : list of float
        One peak frequency per channel, in row order. A channel in which no
        peak is found contributes ``nan``.
    psd_in_alpha_freqs_per_channel : numpy.ndarray
        Alpha-band PSD of the LAST channel processed (kept for backward
        compatibility); an empty array when there are no channels.

    Raises
    ------
    ValueError
        If the alpha band is empty.
    """
    if len(indexes_alpha_freqs) == 0:
        raise ValueError("No frequency bin falls inside the alpha band.")

    alpha_freq_pic_per_channel = []
    # Defined up front so the return statement is valid with zero channels.
    psd_in_alpha_freqs_per_channel = np.array([])

    for channel_psd in np.asarray(psd_welch):
        # Extract psd in alpha freqs
        psd_in_alpha_freqs_per_channel = np.take(channel_psd, indexes_alpha_freqs)

        # peak_finder returns (locations, magnitudes); only locations are needed.
        peak_locations = np.atleast_1d(
            mne.preprocessing.peak_finder(psd_in_alpha_freqs_per_channel)[0]
        )
        if peak_locations.size == 0:
            # No peak in this channel: record a missing value rather than
            # aborting the whole run.
            alpha_freq_pic_per_channel.append(np.nan)
            continue

        # If more than one peak is found, keep the one with the highest index.
        index_of_the_pic = int(peak_locations.max())
        # The index refers to alpha_freqs, not to the full frequency axis.
        alpha_freq_pic_per_channel.append(alpha_freqs[index_of_the_pic])

    return alpha_freq_pic_per_channel, psd_in_alpha_freqs_per_channel
def plot_psd_mean(freqs, alpha_freq_pic_mean, psd_welch_mean, psd_welch_std, alpha_freqs, psd_in_alpha_freqs_mean):
    """Plot the channel-averaged spectrum and save it as a PNG.

    The figure shows the mean PSD, a shaded band of plus/minus one standard
    deviation, a red marker on the alpha peak, and dashed vertical lines at
    the lowest and highest alpha-band frequencies. It is written to
    ``out_dir/psd_mean.png``.

    ``psd_in_alpha_freqs_mean`` is accepted but not used.
    """
    plt.figure()

    # Locate the alpha peak on the full frequency axis to read its power.
    peak_position = np.where(freqs == alpha_freq_pic_mean)
    peak_power = psd_welch_mean[peak_position]

    # Envelope of the shaded standard-deviation area.
    lower_envelope = psd_welch_mean - psd_welch_std
    upper_envelope = psd_welch_mean + psd_welch_std

    # Edges of the band in which the peak was searched.
    band_start, band_stop = min(alpha_freqs), max(alpha_freqs)
    band_edge_style = dict(zorder=4, color='r', linestyle='--', alpha=0.2)

    plt.xlim(xmin=0, xmax=max(freqs))
    plt.plot(freqs, psd_welch_mean, zorder=1)
    plt.scatter(alpha_freq_pic_mean, peak_power, marker='o', color="red", label='alpha peak frequency', zorder=3)
    plt.fill_between(freqs, lower_envelope, upper_envelope, alpha=0.5, label='standard deviation', zorder=2)

    # Only the first edge carries a label so the legend has a single entry.
    plt.axvline(x=band_start, label='alpha band', **band_edge_style)
    plt.axvline(x=band_stop, **band_edge_style)

    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power Spectral Density')
    plt.title('Mean power spectrum across all channels')
    plt.legend()

    plt.savefig('out_dir/psd_mean.png')
def plot_psd_per_channels(freqs, alpha_freq_pic_per_channel, psd_in_alpha_freqs_per_channel, psd_welch, alpha_freqs):
    """Plot the spectrum of every channel on one figure and save it as a PNG.

    One curve is drawn per entry of ``alpha_freq_pic_per_channel`` (only its
    length is used), taking the matching row of ``psd_welch``. Dashed
    vertical lines mark the lowest and highest alpha-band frequencies. The
    figure is written to ``out_dir/psd_channels.png``.

    ``psd_in_alpha_freqs_per_channel`` is accepted but not used.
    """
    plt.figure()

    # One spectrum per channel.
    for channel, _ in enumerate(alpha_freq_pic_per_channel):
        plt.plot(freqs, psd_welch[channel], zorder=1)

    plt.xlim(xmin=0, xmax=max(freqs))

    # Only the first edge carries a label so the legend has a single entry.
    band_edge_style = dict(zorder=4, color='r', linestyle='--', alpha=0.6)
    plt.axvline(x=min(alpha_freqs), label='alpha band', **band_edge_style)
    plt.axvline(x=max(alpha_freqs), **band_edge_style)

    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power Spectral Density')
    plt.title('Power spectrum for all channels')
    plt.legend()

    plt.savefig('out_dir/psd_channels.png')
def main():
    """Run the alpha peak detection described by ``config.json``.

    Reads the path of a PSD CSV file from the ``psd`` key of ``config.json``.
    The CSV is expected to hold an unnamed index column (dropped) followed by
    one column per frequency, the column headers being the frequencies; each
    row is treated as one channel.

    Writes ``out_dir/alpha_peak_frequency.csv`` (one row per channel plus a
    final "mean channels" row) and, through the plotting helpers,
    ``out_dir/psd_mean.png`` and ``out_dir/psd_channels.png``.
    """
    import os  # local import: only needed to prepare the output directory

    # Load inputs from config.json
    with open('config.json') as config_json:
        config = json.load(config_json)
    path_to_input_file = config['psd']

    # Read the outputs of PSD app #
    df_psd_welch = pd.read_csv(path_to_input_file)
    df_psd_welch = df_psd_welch.drop(columns=["Unnamed: 0"])
    psd_welch = df_psd_welch.to_numpy()

    # Column headers are the frequencies, read as strings by pandas.
    # The builtin float replaces np.float, which newer NumPy no longer has.
    freqs = df_psd_welch.columns.to_numpy().astype(float)

    # Get alpha freqs
    alpha_freqs, indexes_alpha_freqs = get_alpha_freqs(freqs)

    # Detect alpha peak on average channels #
    alpha_freq_pic_mean, psd_welch_mean, psd_welch_std, psd_in_alpha_freqs_mean = detect_alpha_peak_mean(
        psd_welch, alpha_freqs, indexes_alpha_freqs
    )

    # Detect alpha peak per channels #
    alpha_freq_pic_per_channel, psd_in_alpha_freqs_per_channel = detect_alpha_peak_per_channels(
        psd_welch, alpha_freqs, indexes_alpha_freqs
    )

    # Create a DataFrame with alpha peak values #
    # The table is built in one go (per-channel rows, then the mean row)
    # because DataFrame.append is gone from recent pandas.
    channels = [f"channel_{i}" for i in range(len(alpha_freq_pic_per_channel))]
    df_alpha_peaks = pd.DataFrame(
        data={
            'channels': channels + ["mean channels"],
            'alpha peak frequency': list(alpha_freq_pic_per_channel) + [alpha_freq_pic_mean],
        }
    )

    # Make sure the output directory exists before anything is written to it.
    os.makedirs('out_dir', exist_ok=True)

    # Save it into a csv
    df_alpha_peaks.to_csv('out_dir/alpha_peak_frequency.csv', index=False)

    # Plot spectrum #
    # Mean spectrum
    plot_psd_mean(freqs, alpha_freq_pic_mean, psd_welch_mean, psd_welch_std, alpha_freqs, psd_in_alpha_freqs_mean)

    # All channels
    plot_psd_per_channels(freqs, alpha_freq_pic_per_channel, psd_in_alpha_freqs_per_channel, psd_welch, alpha_freqs)
# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()