loss = tf.losses.softmax_cross_entropy(onehot_labels=tfy, logits=out)
accuracy = tf.metrics.accuracy(          # return (acc, update_op), and create 2 local variables
    labels=tf.argmax(tfy, axis=1), predictions=tf.argmax(out, axis=1),)[1]
- opt = tf.train.AdamOptimizer(learning_rate=0.01)
+ opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = opt.minimize(loss)

sess = tf.Session()
sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
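The tf.local_variables_initializer() above is needed because tf.metrics.accuracy keeps its running total/count in local variables, and taking [1] selects the update op, so the printed accuracy is a running average over every evaluation so far. A minimal standalone sketch of that streaming behavior, with toy labels and predictions (TF 1.x, not part of this commit):

import tensorflow as tf

labels = tf.constant([0, 1, 1, 0])
preds = tf.constant([0, 1, 0, 0])
acc, update_op = tf.metrics.accuracy(labels=labels, predictions=preds)

with tf.Session() as s:
    s.run(tf.local_variables_initializer())  # the metric's total/count counters are local variables
    print(s.run(update_op))   # 0.75 after the first update
    print(s.run(update_op))   # still 0.75: counts accumulate across calls, giving a running average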

# training
plt.ion()
- for t in range(2000):
+ plt.figure(figsize=(8, 4))
+ accuracies, steps = [], []
+ for t in range(4000):
+     # training
    batch_index = np.random.randint(len(train_data), size=32)
    sess.run(train_op, {tf_input: train_data[batch_index]})
+
    if t % 50 == 0:
-         acc_, pred_ = sess.run([accuracy, prediction], {tf_input: test_data})
-         print(
-             "Step: %i" % t,
-             "| Accurate: %.2f" % acc_,
-         )
+         # testing
+         acc_, pred_, loss_ = sess.run([accuracy, prediction, loss], {tf_input: test_data})
+         accuracies.append(acc_)
+         steps.append(t)
+         print("Step: %i" % t, "| Accurate: %.2f" % acc_, "| Loss: %.2f" % loss_,)

-         # visualize training
+         # visualize testing
+         plt.subplot(121)
        plt.cla()
        for c in range(4):
            bp, = plt.bar(x=c+0.1, height=sum((np.argmax(pred_, axis=1) == c)), width=0.2, color='red')
            bt, = plt.bar(x=c-0.1, height=sum((np.argmax(test_data[:, 21:], axis=1) == c)), width=0.2, color='blue')
        plt.xticks(range(4), ["accepted", "good", "unaccepted", "very good"])
        plt.legend(handles=[bp, bt], labels=["prediction", "target"])
-         plt.pause(0.1)
+         plt.ylim((0, 400))
+         plt.subplot(122)
+         plt.cla()
+         plt.plot(steps, accuracies, label="accuracy")
+         plt.ylim(ymax=1)
+         plt.ylabel("accuracy")
+         plt.pause(0.01)

plt.ioff()
plt.show()
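In the left panel, each bar height is just a per-class count: np.argmax recovers class indices from the softmax predictions (and from the one-hot target columns in test_data[:, 21:]), and summing the boolean comparison counts how many samples fall into each of the 4 classes. A small numpy-only sketch of that counting, with toy data rather than the script's test set:

import numpy as np

pred_ = np.array([[0.7, 0.1, 0.1, 0.1],   # toy softmax outputs for 3 samples
                  [0.2, 0.5, 0.2, 0.1],
                  [0.6, 0.2, 0.1, 0.1]])
pred_classes = np.argmax(pred_, axis=1)    # -> array([0, 1, 0])
for c in range(4):
    print(c, np.sum(pred_classes == c))    # per-class counts: 2, 1, 0, 0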