# plot.py
import sys
from pathlib import Path
import json
import re
import matplotlib.pyplot as plt
import numpy as np
# params = {'text.usetex': True, 'font.family': 'serif'}
# plt.rcParams.update(params)
# Snapshot of matplotlib's default property cycle: one style dict per entry
# of rcParams['axes.prop_cycle'] (e.g. {'color': ...}).
default_cycler = [*plt.rcParams['axes.prop_cycle']]
import os
def get_style(label: str) -> dict:
    """Return extra ``Axes.plot`` keyword arguments for the series *label*.

    No per-series overrides are currently applied: the function always
    returns a new, empty dict. Every curve therefore falls back to
    matplotlib's default property cycle. A fresh dict is returned on each
    call so callers may modify it freely.

    Args:
        label: Name of the plotted series (e.g. a strategy name).

    Returns:
        A dict of style keyword arguments (currently always empty).
    """
    return {}
def _metrics_to_plot(metric_names):
    """Return, in sorted order, the metric names that each get their own panel.

    Bookkeeping keys (``nb_train_samples``, ``epoch``, ``best epoch``) are
    dropped, as is any key containing ``'best'``, ``'all'`` or ``'nll'``.
    """
    excluded = {'nb_train_samples', 'epoch', 'best epoch'}
    return [
        name for name in sorted(set(metric_names) - excluded)
        if not any(tag in name for tag in ('best', 'all', 'nll'))
    ]


def _figure_path(filename):
    """Return the PDF path (as ``str``) saved next to the log file *filename*.

    ``run.txt`` becomes ``run.after.pdf``. A name without a ``.txt`` suffix
    gets ``.after.pdf`` appended, so the output never coincides with the
    input. Directory components are left untouched.
    """
    path = Path(filename)
    if path.suffix == '.txt':
        return str(path.with_name(path.stem + '.after.pdf'))
    return str(path.with_name(path.name + '.after.pdf'))


def plot_after(data, filename):
    """Plot one panel per metric comparing strategies, and save it as a PDF.

    Args:
        data: Parsed JSON log. ``data['metrics']`` maps a strategy name
            (``'random'``, ``'mean'``, ``'variance'``) to a dict holding a
            ``'nb_train_samples'`` sequence and one value sequence per metric.
            The set of metrics is taken from the ``'mean'`` strategy, which
            must therefore be present.
        filename: Path of the log file. The figure is written next to it
            (``run.txt`` -> ``run.after.pdf``).

    Returns:
        The path (as ``str``) of the saved PDF.

    Side effects:
        Prints, for each plotted (strategy, metric) pair, the values rounded
        to 3 decimals joined by ``' & '``, and writes the PDF to disk.
    """
    displayed = {
        'auc': 'Area under the ROC curve',
        'acc': 'Accuracy',
        'map': 'Mean average precision',
        'mean test variance': 'Mean variance',
        'auc_of_mean': 'AUC of mean',
    }
    metrics_by_strategy = data['metrics']
    metric_names = _metrics_to_plot(metrics_by_strategy['mean'])

    # Size the grid to the metrics so every metric gets a panel; squeeze=False
    # keeps `axes` 2-D even when there is a single column.
    ncols = max(len(metric_names), 1)
    fig, axes = plt.subplots(1, ncols, figsize=(4 * ncols, 4), squeeze=False)
    try:
        for ax, metric in zip(axes[0], metric_names):
            # Unknown metrics are titled by their raw key instead of raising.
            ax.set_title(displayed.get(metric, metric))
            ax.set_xlabel('Number of items asked')
            ax.set_xticks(np.arange(4, 24, 4))
            for strategy in ('random', 'mean', 'variance'):
                results = metrics_by_strategy.get(strategy)
                if results is None or metric not in results:
                    continue
                # NOTE(review): dividing nb_train_samples by 20 turns it into
                # "items asked"; presumably 20 samples per item -- confirm.
                x = np.array(results['nb_train_samples']) / 20
                print(strategy, metric,
                      ' & '.join(map(str, np.round(results[metric], 3))))
                ax.plot(x, results[metric], label=strategy,
                        **get_style(strategy))
            # Legend on this panel (plt.legend() would target the last axes).
            if ax.lines:
                ax.legend()
        fig_name = _figure_path(filename)
        fig.savefig(fig_name, bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return fig_name
if __name__ == '__main__':
    # Usage: python plot.py LOGFILE
    # Reads the JSON log LOGFILE, saves the metric plot next to it and opens
    # the resulting PDF with the platform's default viewer.
    import subprocess

    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} LOGFILE')
    logfilename = sys.argv[1]
    with open(logfilename, encoding='utf-8') as f:
        data = json.load(f)
    fig_name = plot_after(data, logfilename)
    # Pass the path as an argument list (no shell), so names containing
    # spaces or shell metacharacters are neither split nor interpreted.
    try:
        if sys.platform == 'darwin':
            subprocess.run(['open', fig_name], check=False)
        elif sys.platform.startswith('win'):
            os.startfile(fig_name)
        else:
            subprocess.run(['xdg-open', fig_name], check=False)
    except OSError as err:
        # Opening a viewer is best-effort; the PDF has already been saved.
        print(f'Saved {fig_name} (could not open viewer: {err})')