Skip to content

Commit b6cd03d

Browse files
tested plot_loss.py
1 parent f83bfb5 commit b6cd03d

File tree

3 files changed

+53
-45
lines changed

3 files changed

+53
-45
lines changed

script/plot_loss.py

+52-45
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,58 @@
11
import matplotlib.pyplot as plt
2-
f = open("log_sample.tsv","r")
3-
datalist=f.readlines()
2+
import sys
43

5-
epoch = []
6-
train_cost = []
7-
train_cost_recons = []
8-
train_cost_temp = []
9-
train_cost_pred = []
10-
valid_cost = []
11-
valid_cost_recons = []
12-
valid_cost_temp = []
13-
valid_cost_pred = []
4+
def plot_loss(loss_txt,loss_image):
5+
f = open(loss_txt,"r")
6+
datalist=f.readlines()
147

15-
for data in datalist[1:]:
16-
a = data.split('\t')
17-
epoch.append(a[1])
18-
train_cost.append(float(a[2]))
19-
train_cost_recons.append(float(a[9]))
20-
train_cost_temp.append(float(a[10])*0.5)
21-
train_cost_pred.append(float(a[11])*0.1)
22-
valid_cost.append(float(a[3]))
23-
valid_cost_recons.append(float(a[12]))
24-
valid_cost_temp.append(float(a[13])*0.5)
25-
valid_cost_pred.append(float(a[14])*0.1)
8+
epoch = []
9+
train_cost = []
10+
train_cost_recons = []
11+
train_cost_temp = []
12+
train_cost_pred = []
13+
valid_cost = []
14+
valid_cost_recons = []
15+
valid_cost_temp = []
16+
valid_cost_pred = []
2617

27-
plt.style.use("grayscale")
28-
fig = plt.figure(figsize=(16.0,6.0))
29-
ax1 = fig.add_subplot(1, 2, 1)
30-
ax1.set_title("train cost")
31-
ax1.set_xlabel("epoch")
32-
ax1.set_xlim([0,1000])
33-
ax1.set_ylim([0.01,5.0])
34-
ax1.plot(train_cost,label="total")
35-
ax1.plot(train_cost_recons,label="recons")
36-
ax1.plot(train_cost_temp,label="alpha*temporal")
37-
ax1.plot(train_cost_pred,label="beta*prediction")
38-
ax1.legend()
18+
for data in datalist[1:]:
19+
a = data.split('\t')
20+
epoch.append(a[1])
21+
train_cost.append(float(a[2]))
22+
train_cost_recons.append(float(a[9]))
23+
train_cost_temp.append(float(a[10])*0.5)
24+
train_cost_pred.append(float(a[11])*0.1)
25+
valid_cost.append(float(a[3]))
26+
valid_cost_recons.append(float(a[12]))
27+
valid_cost_temp.append(float(a[13])*0.5)
28+
valid_cost_pred.append(float(a[14])*0.1)
3929

40-
ax2 = fig.add_subplot(1, 2, 2)
41-
ax2.set_title("valid cost")
42-
ax2.set_xlabel("epoch")
43-
ax2.set_xlim([0,1000])
44-
ax2.set_ylim([0.01,5.0])
45-
ax2.plot(train_cost,label="total")
46-
ax2.plot(train_cost_recons,label="recons")
47-
ax2.plot(train_cost_temp,label="alpha*temporal")
48-
ax2.plot(train_cost_pred,label="beta*prediction")
49-
ax2.legend()
30+
plt.style.use("grayscale")
31+
fig = plt.figure(figsize=(16.0,6.0))
32+
ax1 = fig.add_subplot(1, 2, 1)
33+
ax1.set_title("train cost")
34+
ax1.set_xlabel("epoch")
35+
ax1.set_xlim([0,1000])
36+
ax1.set_ylim([0.01,5.0])
37+
ax1.plot(train_cost,label="total")
38+
ax1.plot(train_cost_recons,label="recons")
39+
ax1.plot(train_cost_temp,label="alpha*temporal")
40+
ax1.plot(train_cost_pred,label="beta*prediction")
41+
ax1.legend()
5042

51-
plt.savefig("plot_loss_sample.png")
43+
ax2 = fig.add_subplot(1, 2, 2)
44+
ax2.set_title("valid cost")
45+
ax2.set_xlabel("epoch")
46+
ax2.set_xlim([0,1000])
47+
ax2.set_ylim([0.01,5.0])
48+
ax2.plot(train_cost,label="total")
49+
ax2.plot(train_cost_recons,label="recons")
50+
ax2.plot(train_cost_temp,label="alpha*temporal")
51+
ax2.plot(train_cost_pred,label="beta*prediction")
52+
ax2.legend()
53+
54+
plt.savefig(loss_image)
55+
56+
if __name__ == '__main__':
57+
args = sys.argv
58+
plot_loss(args[1],args[2])
71.8 KB
Loading

script/test_plot_loss.sh

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python plot_loss.py sample_plot_loss/log_sample.txt sample_plot_loss/sample_plot_loss.png

0 commit comments

Comments
 (0)