Skip to content

Commit 4b07e3d

Browse files
committed
update
1 parent 7e492ea commit 4b07e3d

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

model.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,43 @@
2828
loss = tf.losses.softmax_cross_entropy(onehot_labels=tfy, logits=out)
2929
accuracy = tf.metrics.accuracy( # return (acc, update_op), and create 2 local variables
3030
labels=tf.argmax(tfy, axis=1), predictions=tf.argmax(out, axis=1),)[1]
31-
opt = tf.train.AdamOptimizer(learning_rate=0.01)
31+
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
3232
train_op = opt.minimize(loss)
3333

3434
sess = tf.Session()
3535
sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
3636

3737
# training
3838
plt.ion()
39-
for t in range(2000):
39+
plt.figure(figsize=(8, 4))
40+
accuracies, steps = [], []
41+
for t in range(4000):
42+
# training
4043
batch_index = np.random.randint(len(train_data), size=32)
4144
sess.run(train_op, {tf_input: train_data[batch_index]})
45+
4246
if t % 50 == 0:
43-
acc_, pred_ = sess.run([accuracy, prediction], {tf_input: test_data})
44-
print(
45-
"Step: %i" % t,
46-
"| Accurate: %.2f" % acc_,
47-
)
47+
# testing
48+
acc_, pred_, loss_ = sess.run([accuracy, prediction, loss], {tf_input: test_data})
49+
accuracies.append(acc_)
50+
steps.append(t)
51+
print("Step: %i" % t,"| Accurate: %.2f" % acc_,"| Loss: %.2f" % loss_,)
4852

49-
# visualize training
53+
# visualize testing
54+
plt.subplot(121)
5055
plt.cla()
5156
for c in range(4):
5257
bp, = plt.bar(x=c+0.1, height=sum((np.argmax(pred_, axis=1) == c)), width=0.2, color='red')
5358
bt, = plt.bar(x=c-0.1, height=sum((np.argmax(test_data[:, 21:], axis=1) == c)), width=0.2, color='blue')
5459
plt.xticks(range(4), ["accepted", "good", "unaccepted", "very good"])
5560
plt.legend(handles=[bp, bt], labels=["prediction", "target"])
56-
plt.pause(0.1)
61+
plt.ylim((0, 400))
62+
plt.subplot(122)
63+
plt.cla()
64+
plt.plot(steps, accuracies, label="accuracy")
65+
plt.ylim(ymax=1)
66+
plt.ylabel("accuracy")
67+
plt.pause(0.01)
5768

5869
plt.ioff()
5970
plt.show()

res.png

565 KB
Loading

0 commit comments

Comments
 (0)