使用Tensorflow從0開始搭建精靈寶可夢的檢測APP

本文爲本人原創,轉載請註明來源鏈接

環境要求

  • Tensorflow1.12.0
  • cuda 9.0
  • python3.6.10
  • Android Studio
  • Anaconda

安裝Tensorflow

  1. 使用conda 安裝GPU版Tensorflow

    conda install tensorflow-gpu=1.12.0

  2. 找到tensorflow的安裝位置

    我的位置在: home/jiading/.conda/envs/tensorflow12/lib/python3.6/site-packages/tensorflow

  3. 通過conda安裝的tensorflow是不包括models這一模塊的,需要從Github上下載: https://github.com/tensorflow/models

    將它克隆到tensorflow文件夾下:

  4. 打開models\research\object_detection,按照https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md提示的進行安裝

  5. 運行 python object_detection/builders/model_builder_test.py 測試是否安裝成功

下載和處理數據集

我們採用的數據集是https://www.kaggle.com/lantian773030/pokemonclassification。如果你使用colab訓練,可以直接將數據集下載到colab中: https://blog.csdn.net/qq_35654046/article/details/87621396

原始的數據集只有圖像和類別,可以用於分類,但是用於目標檢測的話需要在此基礎上進一步標定數據,在圖像中框出神奇寶貝的位置。

這裏我們使用labelme這個軟件進行標定。labelme可以直接通過pip安裝: pip install labelme -i https://pypi.tuna.tsinghua.edu.cn/simp le

在終端(Bash和Windows的Powershell都可以)中直接輸出Labelme即可打開軟件.labelme的簡單教程可以看這裏: https://www.cnblogs.com/wangxiaocvpr/p/9997690.html

標定數據後,我們在各個神奇寶貝的文件夾中得到了和原圖像同名的Json文件:

打開json文件,我們可以看到有很長的imageData:

這其實就是對原圖像的儲存,所以我們之後處理時只需要這個json文件即可,由此可以還原出原圖像

如果要達到比較好的效果,要標定的數據還是不少的。

將labelme轉換爲voc格式

我們最終要把數據集轉換爲tfrecord,但是在此之前我們需要將其轉換爲規範的voc格式,以便於再轉爲tfrecord

這裏我們使用Github上提供的腳本: https://github.com/veraposeidon/labelme2Datasets。這個項目的說明也是中文的,我就不多說了(可以使用我fork後修改的版本,下文有說改了哪些地方:https://github.com/JiaDingCN/labelme2Datasets )。

最後得到VOC格式的數據如下:

注意原項目的代碼中有一兩個小bug,這其實無傷大雅,改了就好了,但是原項目沒有生成val數據集的功能,只能生成training和test.所以我改了一點:

原來的split_dataset只有 test_ratio :測試集比例,我加上了'val_ratio'

注意,其實理論上可以直接用這個工具生成coco形式的數據,然後使用tensorflow中tensorflow/models/research/object_detection/dataset_tools/create_coco_tf_record.py來生成tfrecord,但在我實際使用中發現create_coco_tf_record.py製作出來的是分散的數據,如下:

當然人家在代碼中也說了: Please note that this tool creates sharded output files. ,是我自己沒仔細看。這個格式應該也是能用的,但是我目前不知道方法,所以最後就沒有用這個方法

將voc格式數據轉換爲tfrecord

最終我採用的是這篇博客中的代碼,生成的tfrecord如下:

開始訓練

這裏我訓練使用的是Tensorflow lite教程中推薦的COCO SSD MobileNet v1:

下載地址: http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip

當然也可以不選擇預訓練模型,而是從頭訓練。這樣的話就不需要下載上面的文件,你只需要一個config文件即可。該網絡的config文件在object_detection/samples/config/ssd_mobilenet_v1_coco.config。如何配置依然可以看這篇博文: https://www.cnblogs.com/gezhuangzhuang/p/10613468.html

訓練完成後,我們就可以在train_dir中看到得到的模型:

導出圖

我們可以使用object_detection下的export_inference_graph.py導出圖,但是對於ssd模型, 官方推薦 使用export_tflite_ssd_graph.py(親測用上面的那個腳本導出的模型無法轉換爲tflite格式):

python export_tflite_ssd_graph.py --input_type image_tensor --pipeline_config_path /home/jiading/Pokemon/ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync.config --trained.checkpoint_prefix /home/jiading/Pokemon/train/model.ckpt-2955 --output_directory /home/jiading/Pokemon/frozen_inference_graph.pb -add_postprocessing_op True --max_detection 10

測試

我們可以使用tensorflow的object_detection自帶的jupyter notebook腳本來做測試:

將PATH_TO_FROZEN_GRAPH改爲pb文件的位置

需要一個labelmap文件,內容如下:

用一個腳本很容易寫出來,這個就不提了

加載一張圖片

運行結果

轉換爲tensorflow lite模型

~/.conda/envs/tensorflow12/lib/python3.6/site-packages/tensorflow/models/research/object_detection$ tflite_convert --output_file=/home/jiading/Pokemon/tflite/detect.tflite --graph_def_file=/home/jiading/Pokemon/frozen_inference_graph/tflite_graph.pb --input_arrays='normalized_input_image_tensor' --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' --input_shape=1,300,300,3 --allow_custom_ops

部署在安卓端

安卓的例子在 ObjectDetection-Android\examples-master\lite\examples\object_detection\android 下,打開後我們首先需要製作一個labelmap:

原本的例子會利用gradle下載模型,我們可以將地址替換掉

,將我們自己的這兩個文件放進去:

部署時可能遇到的bug

我們可以比對自己的模型和原本的模型在輸入輸出上有沒有區別: https://blog.csdn.net/killfunst/article/details/94301161

import numpy as np
import tensorflow as tf


# Load TFLite model and allocate tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details)
print(output_details)

像我之前在導出圖時設置的 --max_detection 5 ,但是看輸出發現:

上面是我的,下面是原本模型的,改爲10後再導出就沒有問題了

如果還有問題,可以考慮將DetectorActivity中的 private static final boolean TF_OD_API_IS_QUANTIZED 設置爲false。同時,如果出現維度錯誤,可以考慮修改TFLiteObjectDetectionAPIModel.java下的 private static final int NUM_DETECTIONS

最終效果:

一點點換皮

將原項目中的圖標和軟件名換掉之後:

相關文章