Skip to content

Latest commit

 

History

History
216 lines (160 loc) · 9.23 KB

README.md

File metadata and controls

216 lines (160 loc) · 9.23 KB

基于TensorFlow训练花朵识别模型的源码和Demo

下面就通过对现有的 Google Inception-V3 模型进行 retrain ,对 5 种花朵样本数据的进行训练,来完成一个可以识别五种花朵的模型,并将新训练的模型进行测试部属,让大家体验一下完整的流程。

花朵训练样本

安装 TensorFlow (Mac 为例)

其他平台可以直接参考官网说明:Installing TensorFlow

首先检查系统是否安装了 Python

要安装 TensorFlow ,你的系统必须依据安装了以下任一 Python 版本:

  • Python 2.7
  • Python 3.3+

如果做数据处理较多的话,建议安装Anaconda, Anaconda 是一种Python语言的免费增值开源发行版 ,用于进行大规模数据处理, 预测分析, 和科学计算, 致力于简化包的管理和部署。Anaconda使用软件包管理系统Conda进行包管理。安装完成后输入shell下输入python即可查看Anaconda对应的Python 版本,我使用的是Python 2.7.14:

➜  ~ python
Python 2.7.14 |Anaconda, Inc.| (default, Dec  7 2017, 11:07:58)
[GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)] on darwin
Type "help", "copyright", "credits" or "license" for more information.

如果你的系统还没有安装符合以上版本的 Python,现在安装。

通过 pip 安装 TensorFlow

# Python 2
➜ pip install tensorflow
# Python 3
➜ pip3 install tensorflow 

通过官方样例测试 TensorFlow 是否正常安装

进入 Python 环境后输入以下代码,当出现 “Hello, TensorFlow!” 时表明已经安装成功,可正常使用 TensorFlow 了。

➜ python
import tensorflow as tf
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))
Hello, TensorFlow!

准备训练样本

现在我们要训练花朵的识别模型,这是 Google 在TensorFlow里面提供的一个例子,其中包含了5类花朵的训练图片。可以新建个flower_demo文件夹,用于存放数据和训练的模型。

下载并解压得到训练样本

cd flower_demo
# 下载和解压花朵训练数据
curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz

打开训练样本文件夹 flower_photos ,里面有 5 种类别的花:daisy(雏菊), dandelion(蒲公英), roses(玫瑰), sunflowers(向日葵) , tulips(郁金香),总共3672张,每个类别的大概有 600-900 张训练样本图片,具体如下:

cd flower_photos
for dir in `find ./ -maxdepth 1 -type d`;do echo -n -e "$dir\t";find $dir -type f|wc -l ;done;
./	    3672
.//roses	     641
.//sunflowers	     699
.//daisy	     633
.//dandelion	     898
.//tulips	     799

开始训练

下载训练模型使用的 retrain 脚本 该脚本会自动下载 google Inception v3 模型相关文件,retrain.py 是 Google 提供的以ImageNet图片分类模型为基础模型,利用flower_photos数据迁移训练花朵识别模型的脚本。

 cd flower_demo
 curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py

启动训练脚本,开始训练模型

在运行 retrain.py 脚本时,需要配置一些运行命令参数,比如指定模型输入输出相关名称和其他训练要求的配置。其中--how_many_training_steps=4000配置代表训练迭代次数,默认值为4000,如果机器较差,可以适当减少这个值。

➜ cd flower_demo
➜ python3 retrain.py \
  --bottleneck_dir=bottlenecks \
  --how_many_training_steps=4000 \
  --model_dir=inception \
  --summaries_dir=training_summaries/basic \
  --output_graph=retrained_graph.pb \
  --output_labels=retrained_labels.txt \
  --image_dir=flower_photos

这里我们训练4000steps,时间不是很久,我在配备4.2 GHz Intel Core i7处理器的iMac上,不适用GPU大概就5分钟就能训练完成。模型训练完成后,可以看到测试集上Final test accuracy = 92.1%,也就是说我们训练的5类花朵识别模型,在测试集上已经有92%的识别准确率了。其中生成的 retrained_labels.txtretrained_graph.pb 这两个是模型相关文件。

2018-06-02 15:47:00.266119: Step 3950: Train accuracy = 94.0%
2018-06-02 15:47:00.266159: Step 3950: Cross entropy = 0.135385
2018-06-02 15:47:00.327843: Step 3950: Validation accuracy = 93.0% (N=100)
2018-06-02 15:47:00.976543: Step 3960: Train accuracy = 94.0%
2018-06-02 15:47:00.976591: Step 3960: Cross entropy = 0.234760
2018-06-02 15:47:01.038559: Step 3960: Validation accuracy = 91.0% (N=100)
2018-06-02 15:47:01.667255: Step 3970: Train accuracy = 97.0%
2018-06-02 15:47:01.667372: Step 3970: Cross entropy = 0.167394
2018-06-02 15:47:01.731935: Step 3970: Validation accuracy = 87.0% (N=100)
2018-06-02 15:47:02.355780: Step 3980: Train accuracy = 96.0%
2018-06-02 15:47:02.355818: Step 3980: Cross entropy = 0.151201
2018-06-02 15:47:02.418314: Step 3980: Validation accuracy = 91.0% (N=100)
2018-06-02 15:47:03.042364: Step 3990: Train accuracy = 99.0%
2018-06-02 15:47:03.042402: Step 3990: Cross entropy = 0.094383
2018-06-02 15:47:03.103718: Step 3990: Validation accuracy = 91.0% (N=100)
2018-06-02 15:47:03.667861: Step 3999: Train accuracy = 99.0%
2018-06-02 15:47:03.667899: Step 3999: Cross entropy = 0.106797
2018-06-02 15:47:03.729215: Step 3999: Validation accuracy = 94.0% (N=100)
Final test accuracy = 92.1% (N=353)

测试训练完成后的模型

同样的,我们先下载测试模型的脚本 label_image.py,然后从flower_photos/daisy/文件夹下选择图片488202750_c420cbce61.jpg,测试我们训练后的模型的识别准确率,当然你也可以百度搜索一张5类花朵的任意一张图测试识别效果,从下图可以看出,我们训练的算法模型认为这张图属于daisy的概率高达98.9%.

➜ cd flower_demo
➜ curl -L https://goo.gl/3lTKZs > label_image.py
➜ python label_image.py flower_photos/daisy/488202750_c420cbce61.jpg

daisy (score = 0.98921)
sunflowers (score = 0.00948)
dandelion (score = 0.00088)
tulips (score = 0.00038)
roses (score = 0.00005)

蒲公英测试图 有人说label_image.py无法下载,代码如下:

import os, sys
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# change this as you see fit
image_path = sys.argv[1]

# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()

# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line in tf.gfile.GFile("retrained_labels.txt")]

# Unpersists graph from file
with tf.gfile.FastGFile("retrained_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    
    predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
    
    # Sort to show labels of first prediction in order of confidence
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
    
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = predictions[0][node_id]
        print('%s (score = %.5f)' % (human_string, score))

我们随便从百度搜索一张蒲公英(dandelion)的图,保存到test/WechatIMG383.jpg,测试结果显示属于蒲公英的概率为99.59%.

python label_image.py test/WechatIMG383.jpg

dandelion (score = 0.99592)
sunflowers (score = 0.00359)
daisy (score = 0.00042)
tulips (score = 0.00005)
roses (score = 0.00001)

以上基本是模型训练和测试的全部过程,希望能让大家对深度学习的完整项目有个大致的了解。

启动 TensorBoard TensorBoard 是 TensorFlow 自带的训练效果可视化的分析工具,我们可以利用此工具检测和分析模型的收敛情况,比如查看loss的下降、acc的提升和查看可视化的网络结构图等。在我们建的工程目录下,启动tensorboard的具体命令如下:

➜ cd flower_demo
➜ tensorboard --logdir training_summaries

启动 TensorBoard 会占用系统 6006 端口 ,再启动一个新的 TensorBoard 之前,必须要 kill 已在运行的 TensorBoard 任务。

➜ pkill -f "tensorboard

启动浏览器查看 TensorBoard

启动TensorBoard后,可以启动浏览器,在地址栏中输入 localhost:6006 来查看训练进度以及loss和准确度的变化,分析模型等。

训练过程中loss和准确率的变化

花朵识别网络模型的后半部分