|
| 1 | +import os |
| 2 | +import tensorflow as tf |
| 3 | +from datetime import datetime |
| 4 | +from captchaCnn.util import next_batch |
| 5 | +from captchaCnn.captcha_gen import CAPTCHA_HEIGHT, CAPTCHA_WIDTH, CAPTCHA_LEN, CAPTCHA_LIST |
| 6 | + |
| 7 | + |
| 8 | +def weight_variable(shape, w_alpha=0.01): |
| 9 | + ''' |
| 10 | + 增加噪音,随机生成权重 |
| 11 | + :param shape: |
| 12 | + :param w_alpha: |
| 13 | + :return: |
| 14 | + ''' |
| 15 | + initial = w_alpha * tf.random_normal(shape) |
| 16 | + return tf.Variable(initial) |
| 17 | + |
| 18 | + |
| 19 | +def bias_variable(shape, b_alpha=0.1): |
| 20 | + ''' |
| 21 | + 增加噪音,随机生成偏置项 |
| 22 | + :param shape: |
| 23 | + :param b_alpha: |
| 24 | + :return: |
| 25 | + ''' |
| 26 | + initial = b_alpha * tf.random_normal(shape) |
| 27 | + return tf.Variable(initial) |
| 28 | + |
| 29 | + |
| 30 | +def conv2d(x, w): |
| 31 | + ''' |
| 32 | + 局部变量线性组合,步长为1,模式‘SAME’代表卷积后图片尺寸不变,即零边距 |
| 33 | + :param x: |
| 34 | + :param w: |
| 35 | + :return: |
| 36 | + ''' |
| 37 | + return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') |
| 38 | + |
| 39 | + |
| 40 | +def max_pool_2x2(x): |
| 41 | + ''' |
| 42 | + max pooling,取出区域内最大值为代表特征, 2x2pool,图片尺寸变为1/2 |
| 43 | + :param x: |
| 44 | + :return: |
| 45 | + ''' |
| 46 | + return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') |
| 47 | + |
| 48 | + |
| 49 | +def cnn_graph(x, keep_prob, size, captcha_list=CAPTCHA_LIST, captcha_len=CAPTCHA_LEN): |
| 50 | + ''' |
| 51 | + 三层卷积神经网络计算图 |
| 52 | + :param x: |
| 53 | + :param keep_prob: |
| 54 | + :param size: |
| 55 | + :param captcha_list: |
| 56 | + :param captcha_len: |
| 57 | + :return: |
| 58 | + ''' |
| 59 | + # 图片reshape为4维向量 |
| 60 | + image_height, image_width = size |
| 61 | + x_image = tf.reshape(x, shape=[-1, image_height, image_width, 1]) |
| 62 | + |
| 63 | + # layer 1 |
| 64 | + # filter定义为3x3x1, 输出32个特征, 即32个filter |
| 65 | + w_conv1 = weight_variable([3, 3, 1, 32]) |
| 66 | + b_conv1 = bias_variable([32]) |
| 67 | + # rulu激活函数 |
| 68 | + h_conv1 = tf.nn.relu(tf.nn.bias_add(conv2d(x_image, w_conv1), b_conv1)) |
| 69 | + # 池化 |
| 70 | + h_pool1 = max_pool_2x2(h_conv1) |
| 71 | + # dropout防止过拟合 |
| 72 | + h_drop1 = tf.nn.dropout(h_pool1, keep_prob) |
| 73 | + |
| 74 | + # layer 2 |
| 75 | + w_conv2 = weight_variable([3, 3, 32, 64]) |
| 76 | + b_conv2 = bias_variable([64]) |
| 77 | + h_conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(h_drop1, w_conv2), b_conv2)) |
| 78 | + h_pool2 = max_pool_2x2(h_conv2) |
| 79 | + h_drop2 = tf.nn.dropout(h_pool2, keep_prob) |
| 80 | + |
| 81 | + # layer 3 |
| 82 | + w_conv3 = weight_variable([3, 3, 64, 64]) |
| 83 | + b_conv3 = bias_variable([64]) |
| 84 | + h_conv3 = tf.nn.relu(tf.nn.bias_add(conv2d(h_drop2, w_conv3), b_conv3)) |
| 85 | + h_pool3 = max_pool_2x2(h_conv3) |
| 86 | + h_drop3 = tf.nn.dropout(h_pool3, keep_prob) |
| 87 | + |
| 88 | + # full connect layer |
| 89 | + image_height = int(h_drop3.shape[1]) |
| 90 | + image_width = int(h_drop3.shape[2]) |
| 91 | + w_fc = weight_variable([image_height*image_width*64, 1024]) |
| 92 | + b_fc = bias_variable([1024]) |
| 93 | + h_drop3_re = tf.reshape(h_drop3, [-1, image_height*image_width*64]) |
| 94 | + h_fc = tf.nn.relu(tf.add(tf.matmul(h_drop3_re, w_fc), b_fc)) |
| 95 | + h_drop_fc = tf.nn.dropout(h_fc, keep_prob) |
| 96 | + |
| 97 | + # out layer |
| 98 | + w_out = weight_variable([1024, len(captcha_list)*captcha_len]) |
| 99 | + b_out = bias_variable([len(captcha_list)*captcha_len]) |
| 100 | + y_conv = tf.add(tf.matmul(h_drop_fc, w_out), b_out) |
| 101 | + return y_conv |
| 102 | + |
| 103 | + |
| 104 | +def optimize_graph(y, y_conv): |
| 105 | + ''' |
| 106 | + 优化计算图 |
| 107 | + :param y: |
| 108 | + :param y_conv: |
| 109 | + :return: |
| 110 | + ''' |
| 111 | + # 交叉熵计算loss 注意logits输入是在函数内部进行sigmod操作 |
| 112 | + # sigmod_cross适用于每个类别相互独立但不互斥,如图中可以有字母和数字 |
| 113 | + # softmax_cross适用于每个类别独立且排斥的情况,如数字和字母不可以同时出现 |
| 114 | + loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_conv, labels=y)) |
| 115 | + # 最小化loss优化 |
| 116 | + optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) |
| 117 | + return optimizer |
| 118 | + |
| 119 | + |
| 120 | +def accuracy_graph(y, y_conv, width=len(CAPTCHA_LIST), height=CAPTCHA_LEN): |
| 121 | + ''' |
| 122 | + 偏差计算图 |
| 123 | + :param y: |
| 124 | + :param y_conv: |
| 125 | + :param width: |
| 126 | + :param height: |
| 127 | + :return: |
| 128 | + ''' |
| 129 | + # 这里区分了大小写 实际上验证码一般不区分大小写 |
| 130 | + # 预测值 |
| 131 | + predict = tf.reshape(y_conv, [-1, height, width]) |
| 132 | + max_predict_idx = tf.argmax(predict, 2) |
| 133 | + # 标签 |
| 134 | + label = tf.reshape(y, [-1, height, width]) |
| 135 | + max_label_idx = tf.argmax(label, 2) |
| 136 | + correct_p = tf.equal(max_predict_idx, max_label_idx) |
| 137 | + accuracy = tf.reduce_mean(tf.cast(correct_p, tf.float32)) |
| 138 | + return accuracy |
| 139 | + |
| 140 | + |
| 141 | +def train(height=CAPTCHA_HEIGHT, width=CAPTCHA_WIDTH, y_size=len(CAPTCHA_LIST)*CAPTCHA_LEN): |
| 142 | + ''' |
| 143 | + cnn训练 |
| 144 | + :param height: |
| 145 | + :param width: |
| 146 | + :param y_size: |
| 147 | + :return: |
| 148 | + ''' |
| 149 | + # cnn在图像大小是2的倍数时性能最高, 如果图像大小不是2的倍数,可以在图像边缘补无用像素 |
| 150 | + # 在图像上补2行,下补3行,左补2行,右补2行 |
| 151 | + # np.pad(image,((2,3),(2,2)), 'constant', constant_values=(255,)) |
| 152 | + |
| 153 | + acc_rate = 0.95 |
| 154 | + # 按照图片大小申请占位符 |
| 155 | + x = tf.placeholder(tf.float32, [None, height * width]) |
| 156 | + y = tf.placeholder(tf.float32, [None, y_size]) |
| 157 | + # 防止过拟合 训练时启用 测试时不启用 |
| 158 | + keep_prob = tf.placeholder(tf.float32) |
| 159 | + # cnn模型 |
| 160 | + y_conv = cnn_graph(x, keep_prob, (height, width)) |
| 161 | + # 最优化 |
| 162 | + optimizer = optimize_graph(y, y_conv) |
| 163 | + # 偏差 |
| 164 | + accuracy = accuracy_graph(y, y_conv) |
| 165 | + # 启动会话.开始训练 |
| 166 | + saver = tf.train.Saver() |
| 167 | + sess = tf.Session() |
| 168 | + sess.run(tf.global_variables_initializer()) |
| 169 | + step = 0 |
| 170 | + while 1: |
| 171 | + batch_x, batch_y = next_batch(64) |
| 172 | + sess.run(optimizer, feed_dict={x: batch_x, y: batch_y, keep_prob: 0.75}) |
| 173 | + # 每训练一百次测试一次 |
| 174 | + if step % 100 == 0: |
| 175 | + batch_x_test, batch_y_test = next_batch(100) |
| 176 | + acc = sess.run(accuracy, feed_dict={x: batch_x_test, y: batch_y_test, keep_prob: 1.0}) |
| 177 | + print(datetime.now().strftime('%c'), ' step:', step, ' accuracy:', acc) |
| 178 | + # 偏差满足要求,保存模型 |
| 179 | + if acc > acc_rate: |
| 180 | + model_path = os.getcwd() + os.sep + str(acc_rate) + "captcha.model" |
| 181 | + saver.save(sess, model_path, global_step=step) |
| 182 | + acc_rate += 0.01 |
| 183 | + if acc_rate > 0.99: break |
| 184 | + step += 1 |
| 185 | + sess.close() |
| 186 | + |
| 187 | + |
| 188 | +if __name__ == '__main__': |
| 189 | + train() |
| 190 | + |
0 commit comments