-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
61 lines (48 loc) · 2.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
import numpy as np
import src.pred as pr
from src.cnn_model import crack_captcha_cnn
from src.pred import get_next_batch
# Character set used for captcha labels: the ten decimal digits.
number = list("0123456789")
# Letter sets are kept below for reference only; they are not part of the active character set.
# alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z']
# ALPHABET = ['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']

# Hyperparameters.
IMAGE_HEIGHT, IMAGE_WIDTH = 60, 160
# Number of character slots per label (labels are MAX_CAPTCHA * CHAR_SET_LEN wide).
MAX_CAPTCHA = 4
# One class per entry in the character set.
CHAR_SET_LEN = len(number)
# Graph input placeholders, sized from the image dimensions and label layout.
# X: one row of IMAGE_HEIGHT * IMAGE_WIDTH float values per image (fixed input size).
X = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_HEIGHT * IMAGE_WIDTH])
# Y: one label row per image, MAX_CAPTCHA * CHAR_SET_LEN wide (4 * 10 with the current settings).
Y = tf.placeholder(dtype=tf.float32, shape=[None, MAX_CAPTCHA * CHAR_SET_LEN])
# Dropout keep probability, fed per sess.run call.
keep_prob = tf.placeholder(dtype=tf.float32)
# Training
def train_crack_captcha_cnn():
    """Train the captcha CNN and save a checkpoint once accuracy exceeds 50%.

    Builds the loss, optimizer and accuracy ops on top of the logits returned
    by ``crack_captcha_cnn()``, then loops indefinitely: each step trains on a
    batch of 64 samples, and every 10 steps a batch of 100 samples is used to
    measure accuracy. Accuracy is per character position: predictions and
    labels are reshaped to ``[-1, MAX_CAPTCHA, CHAR_SET_LEN]`` and compared
    via argmax over the last axis. As soon as accuracy is above 0.50 the
    session is saved under ``./model/`` and the function returns.

    Returns:
        None. Progress is printed to stdout; the checkpoint is written to disk.
    """
    # Local import keeps this block self-contained without touching file-level imports.
    import os

    # Saver.save fails if the parent directory is missing, which would discard
    # the trained weights at the very end. Create it up front so any filesystem
    # problem surfaces before training starts.
    checkpoint_prefix = "./model/crack_capcha.model"
    checkpoint_dir = os.path.dirname(checkpoint_prefix)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # NOTE(review): crack_captcha_cnn() takes no arguments, so it presumably
    # builds on placeholders of its own module; confirm they are the same
    # tensors as the X / keep_prob fed below.
    output = crack_captcha_cnn()

    # Element-wise sigmoid cross-entropy between logits and the label vector.
    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

    # Per-position predicted vs. labelled class index.
    predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
    max_idx_p = tf.argmax(predict, 2)
    max_idx_l = tf.argmax(tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
    correct_pred = tf.equal(max_idx_p, max_idx_l)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        step = 0
        while True:
            batch_x, batch_y = get_next_batch(64)
            # Dropout is active during training (keep probability 0.75).
            _, loss_ = sess.run(
                [optimizer, loss],
                feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75})
            print(step, loss_)

            # Measure accuracy every 10 steps.
            if step % 10 == 0:
                # NOTE(review): the evaluation batch comes from the same
                # get_next_batch source as the training batches; confirm
                # whether it is meant to be held-out data.
                batch_x_test, batch_y_test = get_next_batch(100)
                # Dropout is disabled for evaluation (keep probability 1.0).
                acc = sess.run(
                    accuracy,
                    feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.})
                print(step, acc)

                # Once accuracy exceeds 50%, save the model and stop training.
                if acc > 0.50:
                    saver.save(sess, checkpoint_prefix, global_step=step)
                    break

            step += 1
# Script entry point: start training only when this file is executed directly.
if "__main__" == __name__:
    train_crack_captcha_cnn()