tf2-yolov3训练自己的数据集
tf2相比于tf1来说更加的友好,支持了Eager模式,代码和keras基本相同,所以代码也很简单,下面就如何用tf2-yolov3训练自己的数据集。项目的代码包:链接: tf2-yolov3.需要自行下载至于tf2-yolov3的原理可以参考这个链接,我觉得是讲的最好一个:链接: yolov3算法的一点理解.tf2-yolov3训练自己的数据集1、配置相关的环境2、使用官方权重进行预测二级目录三
tf2相比于tf1来说更加的友好,支持了Eager模式,代码和keras基本相同,所以代码也很简单,下面就如何用tf2-yolov3训练自己的数据集。
项目的代码包:链接: tf2-yolov3.需要自行下载
至于tf2-yolov3的原理可以参考这个链接,我觉得是讲的最好一个:链接: yolov3算法的一点理解.
tf2-yolov3训练自己的数据集
1、配置相关的环境
我是在linux上跑的,linux上配环境比较简单,相关windows配环境可以看这个博客:
链接: tensorflow-gpu环境搭建超级详细博客.
2、使用官方权重进行预测
1、进入到目标文件夹内
cd yolo_tf2.1/
2、输入 python convert.py
生成tf可用的模型
输出的yolov3.tf 保存在checkpoint里面。
3、开始检测
1)检测照片:
python detect.py --image ./data/people.jpg
这样便是成功的
2)打开摄像头进行预测:
python detect_video.py --video 0
3) 对视频流进行预测
python detect_video.py --video test.mp4 --output ./test_output.avi
经过以上测试,表示这个代码包可以正常的使用了,就可以利用TensorFlow2-yolov3来进行检测了,下一步我们来介绍一下如何训练自己的数据集。
3、训练自己的模型文件,并且识别
1)建立数据集文件夹
其中Annootation:存放标注好的**.xml**文件
JPEGImages : 自己搜集好的一些图片
2)添加图片并且标注(labelimg软件)
软件的下载地址:目标检测标注工具labelImg使用方法
记得要将图片保存到Annootation文件夹里面
…直到标注完所有的图片
3)建立.txt文件
//VOC2012//ImageSets//Main路径下
把你要训练的还有验证的数据文件都给写到.txt文件里面,方便程序对数据进行读取。
下面这段程序可以获取图片名称,因为每个人的图片的名称不一样,所以需要做相应的调整:
import os,glob
path = r"C:\Users\TSK\Desktop\yolo_tf2.1\VOCdevkit\VOC2012\JPEGImages"
path_list=os.listdir(path)
path_list.sort() #对列表进行格式化
for i in path_list[0:320] : #训练的样本
print(i[:-4]+" -1")
for i in path_list[320:-1] : #验证的样本
print(i[:-4]+" -1")
4)建立标签.names文件
在yolo_tf2.1/data文件夹下
里面写入的就是自己要训练的类别,有哪些类,就写入那些名称。
5)生成tfrecord文件(train和val)
这个文件的作用大概就是:这么多的图片,你让TensorFlow挨个去读取的话,很占内存,很费时间,原来很占内存,现在只用占一点点,终究一个还是节省内存,读取速度加快。
通过 .txt文件来读取
看自己的 .txt 文件是什么名字,这个地方得相应的改一下
训练集:
python tools/voc2012.py --data_dir ./VOCdevkit_fire/VOC2012 --split train --output_file ./data/voc2012_train_dlsb.tfrecord --classes ./data/dlsb.names
先解释一下部分含义,感觉没啥好解释的,都是字面意思 (捂脸笑)
一开始可能会出现这种情况,转tfrecord文件的时候可能会出点问题
然后我百度了一下,发现是这样一个原因:
错误的意思是:Unicode的解码(Decode)出现错误了,以gbk编码的方式去解码(该字符串变成Unicode),但是此处通过gbk的方式,却无法解码(can’tdecode).’'illegalmultibyte sequence"的意思是非法的多字节序列,也就是说无法解码了。
我在源代码中添加了这个就可以正确的执行了,encoding = 'utf-8'
如下:
我觉得还是那个.txt文件的格式不对,所以他读取不了,给它特定的格式就能够正确的读取了。
测试集:
python tools/voc2012.py --data_dir ./VOCdevkit_fire/VOC2012 --split val --output_file ./data/voc2012_val_dlsb.tfrecord --classes ./data/dlsb.names
出现这样表示已经转tfrecord成功。
6)进行迁移训练
python train.py --dataset ./data/voc2012_train_dlsb.tfrecord --val_dataset ./data/voc2012_val_dlsb.tfrecord --classes ./data/dlsb.names --num_classes 3 --mode fit --transfer darknet --batch_size 4 --epochs 150 --weights ./checkpoints/yolov3.tf --weights_num_classes 80
先简单的进行150个epochs的训练:
静待结果。。。。
损失函数下降到16,还不是特别好
7)进行模型测试
python detect.py --classes ./data/three.names --num_classes 3 --weights ./checkpoints/yolov3_train_150.tf --image ./000002.jpg --yolo_score_threshold 0.3
准确度还不是很高…正在改进中…
更多推荐
所有评论(0)