-
Notifications
You must be signed in to change notification settings - Fork 19
/
utils.py
58 lines (51 loc) · 1.79 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import requests
import zipfile
import numpy as np
import matplotlib.pyplot as plt
def get_data():
print('Downloading started')
url = 'http://bbci.de/competition/download/competition_iv/BCICIV_1calib_1000Hz_mat.zip'
username = 'replace_with_your_own_username'
password = 'replace_with_your_own_password'
req = requests.get(url, auth=(username,password))
filename = url.split('/')[-1]
with open(filename,'wb') as output_file:
output_file.write(req.content)
print('Downloading Completed')
# Change to your path
print('Unzipping')
path_to_zip_file = 'BCICIV_1calib_1000Hz_mat.zip'
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
zip_ref.extractall('data/')
print('Unzipping Completed')
return
def parse_log(fp):
log = {
'log_info': fp,
'train_loss': [],
'val_loss': [],
'train_acc':[],
'val_acc': []
}
f = open(fp, 'r')
lines = f.readlines()
for line in lines:
if 'train Loss' in line:
line = line.split()
log['train_loss'].append(float(line[2]))
log['train_acc'].append(float(line[4]))
elif 'val Loss' in line:
line = line.split()
log['val_loss'].append(float(line[2]))
log['val_acc'].append(float(line[4]))
return log
def plot_log(fp):
log = parse_log(fp)
f, ax = plt.subplots(1, 2, figsize=(15,5))
ax[0].plot(log['train_acc'], label='Training Accuracy', linestyle='dashed')
ax[0].plot(log['val_acc'], label='Validation Accuracy', linestyle='dashed')
ax[0].legend()
ax[1].plot(log['train_loss'], label='Training Loss', linestyle='dashed')
ax[1].plot(log['val_loss'], label='Validation Loss', linestyle='dashed')
ax[1].legend()
f.suptitle(f'Experiment Log: {fp}')