-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathread_parallel_results.py
48 lines (38 loc) · 1.3 KB
/
read_parallel_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import pickle
import numpy as np
import ipdb
def append_data(name, quantity):
    """Unpickle the object stored at *name* and append it to *quantity*.

    Args:
        name: Path of a pickle file.
        quantity: List that receives the loaded object; it is mutated in place.

    Returns:
        The same ``quantity`` list object, so callers can write
        ``acc = append_data(path, acc)``.

    Note:
        ``pickle.load`` can execute arbitrary code while unpickling; only use
        this on result files you produced yourself.
    """
    with open(name, 'rb') as handle:
        loaded = pickle.load(handle)
    quantity.append(loaded)
    return quantity
def _load_pickle(path):
    """Return the object unpickled from *path* (use on trusted result files only)."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def consolidate_results(root, max_parallel_id=0, epochs=None):
    """Summarise accuracy across parallel runs for each epoch checkpoint.

    For every ``run`` in ``0..max_parallel_id`` the files are read from
    ``<root>/<run>/results/``:

    * ``linev.pkl`` -- only the first-axis length of the stacked arrays is
      used, as the sample size ``n`` of the confidence interval;
    * ``wi_final<epochs>.pkl`` -- accuracy values whose mean/std are reported.

    Runs whose ``wi_final<epochs>.pkl`` is missing are skipped for that epoch
    count; an epoch count with no available run is skipped entirely. For each
    remaining epoch count the line ``accuracy = <mean> +- <1.96*std/sqrt(n)>``
    is printed and also written, prefixed with the epoch count, to
    ``<root>/results.txt`` (the file is overwritten on every call).

    Args:
        root: Directory containing one sub-directory per run.
        max_parallel_id: Highest run index to read (inclusive).
        epochs: Iterable of epoch counts to summarise; defaults to
            ``range(0, 600, 10)``.
    """
    if epochs is None:
        epochs = range(0, 600, 10)
    # linev.pkl does not depend on the epoch count: read it at most once per run.
    linev_cache = {}
    with open(os.path.join(root, 'results.txt'), 'w') as f:
        for num_epochs in epochs:
            acc_all_le = []
            acc_all = []
            for run in range(max_parallel_id + 1):
                run_dir = os.path.join(root, str(run), 'results')
                final_name = os.path.join(run_dir, 'wi_final%d.pkl' % num_epochs)
                if not os.path.exists(final_name):
                    # This run has not produced a checkpoint for this epoch count.
                    continue
                if run not in linev_cache:
                    linev_cache[run] = _load_pickle(os.path.join(run_dir, 'linev.pkl'))
                acc_all_le.append(linev_cache[run])
                acc_all.append(_load_pickle(final_name))
            if not acc_all:
                continue
            # NOTE(review): the CI sample size comes from linev.pkl while the
            # mean/std come from wi_final*.pkl -- confirm their lengths match.
            nTasks = np.hstack(acc_all_le).shape[0]
            acc_all = np.hstack(acc_all)
            acc_mean = np.mean(acc_all)
            # 95% confidence half-width under a normal approximation.
            acc_std = 1.96 * np.std(acc_all) / np.sqrt(nTasks)
            strn = "accuracy = %4.2f +- %4.2f" % (acc_mean, acc_std)
            print(strn)
            f.write("epochs = %d: %s\n" % (num_epochs, strn))
if __name__ == "__main__":
    import argparse

    # Take the results root and run count from the command line; the script
    # has no other source for them.
    parser = argparse.ArgumentParser(
        description="Summarise accuracy of parallel runs stored under ROOT.")
    parser.add_argument("root", help="directory containing one sub-directory per run")
    parser.add_argument("--max-parallel-id", type=int, default=0,
                        help="highest run index to read (inclusive)")
    args = parser.parse_args()
    consolidate_results(args.root, args.max_parallel_id)