-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparam_comp.py
89 lines (71 loc) · 1.92 KB
/
param_comp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Input and output directories are resolved relative to the current working
# directory, so the script must be launched from the project root.
home_path = os.getcwd()
results_path = f'{home_path}/results/'
plot_path = f'{home_path}/plots/'

# Result files (inside results_path) to include in the comparison.
fnames = ['training_v1', 'training_v2']  # 'training_v3' is currently excluded
def _summarise_run(params):
    """Reduce one model's raw result record to the fields used for plotting.

    Args:
        params: Mapping read from a results JSON file. The keys accessed are
            'embeddings', 'embedding_dim', 'seq_length', 'units',
            'trainLossHist' and 'valLossHist'.

    Returns:
        A flat dict with the hyper-parameters, the number of layers (the
        length of ``params['units']``; the raw list is not carried over) and
        the train/validation loss averaged over the last 50 history entries
        and divided by the sequence length.

    Raises:
        KeyError: If any of the keys listed above is missing.
    """
    seq_length = params['seq_length']
    # NOTE(review): plot titles elsewhere in this file say "last 5000 steps";
    # presumably each history entry covers 100 steps -- confirm.
    return {
        'embeddings': params['embeddings'],
        'embedding_dim': params['embedding_dim'],
        'seq_length': seq_length,
        'layers': len(params['units']),
        'trainLoss': np.mean(params['trainLossHist'][-50:]) / seq_length,
        'valLoss': np.mean(params['valLossHist'][-50:]) / seq_length,
    }


# Collect one summary row per model across all result files.
# NOTE(review): rows are keyed by model name only, so a model name that occurs
# in more than one file keeps just the entry from the last file read --
# confirm that this is intended.
dataDict = {}
for fname in fnames:
    fpath = results_path + fname
    # Explicit encoding so JSON decoding does not depend on the locale.
    with open(fpath, 'r', encoding='utf-8') as fp:
        results = json.load(fp)
    for model, params in results.items():
        dataDict[model] = _summarise_run(params)

# One row per model, one column per field. Building with orient='index' lets
# each column keep its own dtype; transposing a frame whose columns are
# per-model (and therefore mixed-type) loses the per-field dtypes.
plotData = pd.DataFrame.from_dict(dataDict, orient='index')
# Bar chart: mean training loss per number of layers, one bar per value of
# 'embeddings'.
# The loss column is selected *before* aggregating, so no other column of
# plotData is averaged (pandas >= 2.0 raises TypeError when mean() meets a
# non-numeric column). The cast guards against an object-dtype loss column.
layers = (
    plotData.astype({'trainLoss': float})
    .groupby(['layers', 'embeddings'])['trainLoss']
    .mean()
    .reset_index()
)

g = sns.catplot(
    data=layers,
    kind="bar",
    x="layers",
    y="trainLoss",
    hue="embeddings",
    palette="Blues",
    alpha=1.0,
    height=6,
)
# Fixed y-axis window; values outside 1.4-2.2 are clipped from view.
plt.ylim(1.4, 2.2)
plt.xlabel('Number of layers')
plt.ylabel('Loss', rotation=0, labelpad=15)
# NOTE(review): trainLoss averages the last 50 history entries; "5000 steps"
# presumably assumes one entry per 100 steps -- confirm.
plt.title('Mean loss for last 5000 steps')

# Make sure the output directory exists, and name the file with an explicit
# extension (matching the other plot saved by this script).
os.makedirs(plot_path, exist_ok=True)
plt.savefig(plot_path + 'compEmbedd.png', dpi=200)
plt.show()
# Bar chart: mean training loss per number of layers, one bar per value of
# 'seq_length'.
# The loss column is selected *before* aggregating, so the remaining columns
# (including 'embeddings') are never averaged; pandas >= 2.0 raises TypeError
# when mean() meets a non-numeric column. The cast guards against an
# object-dtype loss column.
layers = (
    plotData.astype({'trainLoss': float})
    .groupby(['layers', 'seq_length'])['trainLoss']
    .mean()
    .reset_index()
)

g = sns.catplot(
    data=layers,
    kind="bar",
    x="layers",
    y="trainLoss",
    hue="seq_length",
    palette='rocket',
    alpha=1.0,
    height=6,
)
# Fixed y-axis window; values outside 1.4-2.2 are clipped from view.
plt.ylim(1.4, 2.2)
plt.xlabel('Number of layers')
plt.ylabel('Loss', rotation=0, labelpad=15)
# NOTE(review): trainLoss averages the last 50 history entries; "5000 steps"
# presumably assumes one entry per 100 steps -- confirm.
plt.title('Mean loss for last 5000 steps')

# Make sure the output directory exists before writing the figure.
os.makedirs(plot_path, exist_ok=True)
plt.savefig(plot_path + 'compSeqLen.png', dpi=200)
plt.show()