二. YOLO数据预处理流程

2.1 流程简述

原始数据一般是图片数据和标注数据,其中标注数据目前有两种,一种是VOC格式的.xml文件存储标注信息,另外一种标注格式CoCo用json来存储标注信息,这次我给大家准备的是VOC格式的转YOLO,此类脚本很多。主要涉及三个方面:

  1. 数据集划分(2.2节)
  2. VOC格式转YOLO格式(2.3节)
  3. YOLO网络训练数据加载(2.4节)

三个方面的准备,为下一步网络训练准备好数据。

如果对你能有一些帮助,点个赞不过分叭,点赞是我继续创作动力。

2.2 数据格式统一

我习惯采用VOC数据格式,目录结构如下:
在这里插入图片描述

下面的脚本主要是对voc数据集格式的数据进行划分训练集 集验证集 测试集,具体内容看注解

import os
import random
from tqdm import tqdm
"""
作者:小小博
时间:2022.3.21
环境:打完游戏的夜晚
本脚本有功能:
1.对VOC格式的数据集进行训练集、验证集、测试集划分
"""
# -------------------------------------------------------#
#   voc_annotation.py
#   脚本主要是对voc数据集格式的数据进行划分训练集 集验证集 测试集
#   目录结构如下:
#   |-VOCdevkit
#       |-VOC2007
#           |-Annotations
#           |-ImageSets
#               |-Main
#                   |-test.txt
#                   |-train.txt
#                   |-trainval.txt
#                   |-val.txt
#           |-JPEGImages
# -------------------------------------------------------#

# -------------------------------------------------------#
#  classes_path 是存放voc(20类)类别名称的txt文件
#  trainval_percent = train_percent 为训练集和验证集的比例
#  例如数据集有1000张图片 trainval_percent 训练集和验证集占90% train_percent再占90%
#  那么训练集+验证集 = 900 张  测试集 = 100张
#  其中 810张 属于训练集 90张属于验证集 810+90 = 900
#  VOCdevkit_path  VOC 数据集的目录
# -------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'
trainval_percent = 0.9
train_percent = 0.9
VOCdevkit_path = 'VOCdevkit'

if __name__ == "__main__":
    # -------------------------------------------------------#
    #  random.seed(0)随机数种子,使得我们每次生成的训练集和测试集划分是一致
    # -------------------------------------------------------#
    random.seed(0)
    # -------------------------------------------------------#
    #  xmlfilepath 图像标签地址 .xml
    #  save_path  划分结果的存储地址 .txt
    #  temp_xml Annotations下所有xml文件名['000001.xml', '000002.xml', ...,'001000.xml']
    #  如果 Annotations 还有其他类型的文件 例如 .txt .jpg 得要过滤一下
    #  total_xml 全部存的是
    # -------------------------------------------------------#
    xml_path = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
    save_path = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
    temp_xml = os.listdir(xml_path)
    total_xml = [xml for xml in temp_xml  if xml.endswith(".xml")]
    num = len(total_xml)    # 标签总数
    list = range(num)   # 标签总数(0,num)
    tv = int(num*trainval_percent)  # 训练集+验证集的总数
    tr = int(tv*train_percent)  # 训练集总数
    trainval = random.sample(list, tv)    # 例如1000张样本 随机选取900张的图片的索引作为训练和验证集索引
    train = random.sample(trainval, tr)    # 900 索引中再随机810的索引作为训练集索引

    print("训练集和验证集的数量: ", tv)
    print("训练集的数量: ", tr)
    print("验证集的数量: ", tv - tr)
    print("测试集的数量: ", num - tv)

    ftrainval = open(os.path.join(save_path,'trainval.txt'), 'w')   # 训练集和验证集txt
    ftrain = open(os.path.join(save_path,'train.txt'), 'w')         # 训练集txt
    fval = open(os.path.join(save_path,'val.txt'), 'w')             # 验证集txt
    ftest = open(os.path.join(save_path,'test.txt'), 'w')           # 测试集txt

    tq_epochs = tqdm(list)                                          # 进度条
    for i in tq_epochs:
        # -------------------------------------------------------#
        #  total_xml[i][:-4] 取出每一个文件名从.xml前面文件名不需要文件后缀
        #  i 就是标签所在的索引号
        #  如果 索引 i 在 trainval 就写入训练集和验证集 否则就写入测试集
        #       如果 索引 i 在 trainval 同时也在train  就写入训练集 否则 就写入验证集
        # -------------------------------------------------------#
        name = total_xml[i][:-4]+'\n'
        if i in trainval:
            ftrainval.write(name)
            if i in train:
                ftrain.write(name)
            else:
                fval.write(name)
        else:
            ftest.write(name)

    ftrainval.close()   # 关闭写入流
    ftrain.close()      # 关闭写入流
    fval.close()        # 关闭写入流
    ftest.close()       # 关闭写入流

2.3 VOC格式转YOLO格式

VOC标签文件是存在.xml文件中的,而YOLO需要的格式.txt中 (11 0.344193 0.611 0.416431 0.262) 第一个数是类别索引,后面四个数是(x,y,w,h)归一化的坐标信息,具体操作看下面脚本的注释。

"""
作者:小小博
时间:2022.3.21
本脚本有三个个功能:
1.根据train.txt和val.txt将voc数据集标注信息(.xml)转为yolo标注格式(.txt),生成dataset文件(train+val)
2.统计训练集、验证集和测试集的数据并生成相应train_path.txt和val_path.txt test_path.txt文件
3.创建data.data文件,记录classes个数, train以及val数据集文件(.txt)路径和dataset_classes.names文件路径
"""
import os
from tqdm import tqdm
from lxml import etree
import json
import shutil
from os.path import *

# -------------------------------------------------------#
#  对数据和标签进行处理把VOC格式转为YOLO需要的格式 (class_index ,x,y,w,h)
#  dir_path 根目录
#  images_path 图片的绝对地址
#  xml_path 标签的绝对地址
#  train_txt_path 训练集的绝对地址
#  val_txt_path 验证集的绝对地址
#  test_txt_path 测试集的绝对地址
#  label_json_path 标签名称的json文件绝对地址
#  save_file_root 保存路径地址
# -------------------------------------------------------#
dir_path = dirname(abspath(__file__))
images_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "JPEGImages")
xml_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "Annotations")
train_txt_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "ImageSets/Main", "train.txt")
val_txt_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "ImageSets/Main", "val.txt")
test_txt_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "ImageSets/Main", "test.txt")
label_json_path = os.path.join(dir_path,"data", "pascal_voc_classes.json")
save_file_root = os.path.join(dir_path, "dataset")

# -------------------------------------------------------#
#  保存训练集、验证集、测试集图片绝对地址到.txt
#  同时把三个绝对地址和类别数保存到dataset_data.txt
#  方便后续进行数据装载操作 dataset
# -------------------------------------------------------#
train_annotation_dir = os.path.join(dir_path, "dataset", "train", "labels")
val_annotation_dir = os.path.join(dir_path, "dataset", "val", "labels")
test_annotation_dir = os.path.join(dir_path, "dataset", "test", "labels")
train_path_txt = os.path.join(dir_path, "train_path.txt")
val_path_txt = os.path.join(dir_path, "val_path.txt")
test_path_txt = os.path.join(dir_path, "test_path.txt")
dataset_data = os.path.join(dir_path, "dataset.data")
classes_label = os.path.join(dir_path, "dataset_classes.names")
# -------------------------------------------------------#
#  检查文件/文件夹都是否存在
# -------------------------------------------------------#
assert os.path.exists(images_path), "images path not exist..."
assert os.path.exists(xml_path), "xml path not exist..."
assert os.path.exists(train_txt_path), "train txt file not exist..."
assert os.path.exists(val_txt_path), "val txt file not exist..."
assert os.path.exists(test_txt_path), "test txt file not exist..."
assert os.path.exists(label_json_path), "label_json_path does not exist..."
# -------------------------------------------------------#
#  如果dataset不存在 就创建一个
# -------------------------------------------------------#
if os.path.exists(save_file_root) is False:
    os.makedirs(save_file_root)

# -------------------------------------------------------#
# 将xml文件解析成字典形式
# {'bndbox': {'xmin': '48', 'ymin': '240', 'xmax': '195', 'ymax': '371'}}
# -------------------------------------------------------#

def parse_xml_to_dict(xml):
    if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
        return {xml.tag: xml.text}
    # -------------------------------------------------------#
    # 递归遍历标签信息
    # 因为object可能有多个,所以需要放入列表里
    # -------------------------------------------------------#
    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}

# -------------------------------------------------------#
#   文件夹的格式
#   |-dataset
#       |-train
#           |-images
#           |-labels
#       |-val
#           |-images
#           |-labels
#       |-test
#           |-images
#           |-labels
#   先判断文件夹是否存在,不存在就创建
# -------------------------------------------------------#
def translate_info(file_names: list, save_root: str, class_dict: dict, train_val='train'):
    '''
    :param file_names: 文件名称列表不含文件后缀
    :param save_root: 转换后的存储地址
    :param class_dict: json格式{key:value,key:value,...,key:value}标签名索引
    :param train_val: 保存文件夹名
    :return:
    '''
    save_txt_path = os.path.join(save_root, train_val, "labels")
    if os.path.exists(save_txt_path) is False:
        os.makedirs(save_txt_path)
    save_images_path = os.path.join(save_root, train_val, "images")
    if os.path.exists(save_images_path) is False:
        os.makedirs(save_images_path)
    # -------------------------------------------------------#
    #  tqdm 进度条,对处理过程进行可视化,建议自己查API 学习一下
    # -------------------------------------------------------#
    for file in tqdm(file_names, desc="translate {} file...".format(train_val)):
        # -------------------------------------------------------#
        #  检查下图像文件是否存在,如果你的图片是.png 自行修改
        # -------------------------------------------------------#
        img_path = os.path.join(images_path, file + ".jpg")
        assert os.path.exists(img_path), "file:{} not exist...".format(img_path)
        # -------------------------------------------------------#
        #  检查xml文件是否存在
        # -------------------------------------------------------#
        xml_full_path = os.path.join(xml_path, file + ".xml")
        assert os.path.exists(xml_full_path), "file:{} not exist...".format(xml_full_path)
        # -------------------------------------------------------#
        #  读xml内容并对其进行处理
        #  <size>
        # 		<width>353</width>
        # 		<height>500</height>
        # 		<depth>3</depth>
        # 	</size>
        # -------------------------------------------------------#
        with open(xml_full_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = parse_xml_to_dict(xml)["annotation"]
        img_height = int(data["size"]["height"])
        img_width = int(data["size"]["width"])
        # -------------------------------------------------------#
        #  将.xml的内容转换为YOLO的格式并写入.txt文件中
        # -------------------------------------------------------#
        with open(os.path.join(save_txt_path, file + ".txt"), "w") as f:
            assert "object" in data.keys(), "file: '{}' lack of object key.".format(xml_full_path)
            for index, obj in enumerate(data["object"]):
                # -------------------------------------------------------#
                #  获取每个目标框架的(x1,y1),(x2,y2) 左上 和 右下标
                #  例如:
                #  <object>
                # 		<name>dog</name>
                # 		<pose>Left</pose>
                # 		<truncated>1</truncated>
                # 		<difficult>0</difficult>
                # 		<bndbox>
                # 			<xmin>48</xmin>
                # 			<ymin>240</ymin>
                # 			<xmax>195</xmax>
                # 			<ymax>371</ymax>
                # 		</bndbox>
                # 	</object>
                # -------------------------------------------------------#
                xmin = float(obj["bndbox"]["xmin"])
                xmax = float(obj["bndbox"]["xmax"])
                ymin = float(obj["bndbox"]["ymin"])
                ymax = float(obj["bndbox"]["ymax"])
                # -------------------------------------------------------#
                #  通过key 去查 value 把目标的名 数字化
                # -------------------------------------------------------#
                class_index = class_dict[obj["name"]] - 1  # 目标id从0开始
                # -------------------------------------------------------#
                #  将box信息转换到yolo格式
                # -------------------------------------------------------#
                x_center = xmin + (xmax - xmin) / 2
                y_center = ymin + (ymax - ymin) / 2
                w = xmax - xmin
                h = ymax - ymin
                # -------------------------------------------------------#
                #  yolo格式(class_index,x,y,w,h) 进行归一化 方便网络训练
                #  绝对坐标转相对坐标,保存6位小数
                # info=['6', '0.13', '0.857357', '0.116', '0.141141']
                # -------------------------------------------------------#
                x_center = round(x_center / img_width, 6)
                y_center = round(y_center / img_height, 6)
                w = round(w / img_width, 6)
                h = round(h / img_height, 6)
                info = [str(i) for i in [class_index, x_center, y_center, w, h]]
                # -------------------------------------------------------#
                #  如果只有一个目标框就直接写入,并且用空格分隔
                #  如果有多个目标框就要加换行符,并且用空格分隔
                # -------------------------------------------------------#
                if index == 0:
                    f.write(" ".join(info))
                else:
                    f.write("\n" + " ".join(info))
        # -------------------------------------------------------#
        #  将图像复制到 save_images_path
        # -------------------------------------------------------#
        shutil.copyfile(img_path, os.path.join(save_images_path, img_path.split(os.sep)[-1]))


def create_class_names(class_dict: dict):
    # -------------------------------------------------------#
    #  keys集合 目标名称 ['aeroplane', 'bicycle', 'bird',...,'tvmonitor']
    # -------------------------------------------------------#
    keys = class_dict.keys()
    with open(dir_path+"/dataset_classes.names", "w") as w:
        for index, k in enumerate(keys):
            if index + 1 == len(keys):
                w.write(k)
            else:
                w.write(k + "\n")
# -------------------------------------------------------#
# 创建记录图像的列表 .txt
# 例如:
#   F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000007.jpg
#   F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000035.jpg
#   F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000036.jpg
#   F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000049.jpg
#   F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000055.jpg
#   txt_path =  train_path.txt  val_path.txt test_path.txt保存地址
#   dataset_dir = F:\mypytorch\pythonProject1\myYolo2\dataset\test\labels
# -------------------------------------------------------#
def calculate_data_txt(txt_path, dataset_dir):
    with open(txt_path, "w") as w:
        for file_name in os.listdir(dataset_dir):
            print(file_name)
            if file_name == "classes.txt":
                continue
            # -------------------------------------------------------#
            #  F:\mypytorch\pythonProject1\myYolo2\dataset\test\labels
            #  将上述的绝对地址labels替换为images,同时把文件名从.分割加上JPG格式
            #  在加上换行符 一条一条的写入  train_path.txt val_path.txt test_path.txt 文件
            # -------------------------------------------------------#
            img_path = os.path.join(dataset_dir.replace("labels", "images"),
                                    file_name.split(".")[0]) + ".jpg"
            line = img_path + "\n"
            assert os.path.exists(img_path), "file:{} not exist!".format(img_path)
            w.write(line)
    w.close()
# -------------------------------------------------------#
# 创建记录数据的列表 dataset.data
# 例如:
#   classes=20
#   train=F:\mypytorch\pythonProject1\myYolo2\train_path.txt
#   valid=F:\mypytorch\pythonProject1\myYolo2\val_path.txt
#   test=F:\mypytorch\pythonProject1\myYolo2\test_path.txt
#   names=F:\mypytorch\pythonProject1\myYolo2\dataset_classes.names
# -------------------------------------------------------#
def create_dataset_data(create_data_path, label_path, train_path, val_path,test_path, classes_info):

    with open(create_data_path, "w") as w:
        w.write("classes={}".format(len(classes_info)) + "\n")  # 记录类别个数
        w.write("train={}".format(train_path) + "\n")           # 记录训练集对应txt文件路径
        w.write("valid={}".format(val_path) + "\n")             # 记录验证集对应txt文件路径
        w.write("test={}".format(test_path) + "\n")             # 记录测试集对应txt文件路径
        w.write("names={}".format(classes_label) + "\n")        # 记录label.names文件路径
    w.close()

def main():
    # -------------------------------------------------------#
    #  读入json文件,并转为为json格式{key:value,key:value,...,key:value}形式
    # -------------------------------------------------------#
    json_file = open(label_json_path, 'r')
    class_dict = json.load(json_file)
    # -------------------------------------------------------#
    #  读取train.txt中的所有行信息,删除空行
    # -------------------------------------------------------#
    with open(train_txt_path, "r") as r:
        train_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # -------------------------------------------------------#
    #  读取训练集voc格式转换为YOLO格式
    # -------------------------------------------------------#
    translate_info(train_file_names, save_file_root, class_dict, "train")


    with open(val_txt_path, "r") as r:
        val_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # -------------------------------------------------------#
    #  读取验证集voc格式转换为YOLO格式
    # -------------------------------------------------------#
    translate_info(val_file_names, save_file_root, class_dict, "val")

    with open(test_txt_path, "r") as r:
        test_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # -------------------------------------------------------#
    #  读取测试集voc格式转换为YOLO格式
    # -------------------------------------------------------#
    translate_info(test_file_names, save_file_root, class_dict, "test")
    # -------------------------------------------------------#
    #  创建dataset_classes.names文件
    # -------------------------------------------------------#
    create_class_names(class_dict)
    # -------------------------------------------------------#
    #  统计训练集和验证集的数据并生成相应txt文件
    # -------------------------------------------------------#
    calculate_data_txt(train_path_txt, train_annotation_dir)
    calculate_data_txt(val_path_txt, val_annotation_dir)
    calculate_data_txt(test_path_txt, test_annotation_dir)
    classes_info = [line.strip() for line in open(classes_label, "r").readlines() if len(line.strip()) > 0]
    # -------------------------------------------------------#
    #  dataset.data文件,记录classes个数, train、val、test数据集文件(.txt)路径和label.names文件路径
    # -------------------------------------------------------#
    create_dataset_data(dataset_data, classes_label, train_path_txt, val_path_txt,test_path_txt, classes_info)

if __name__ == "__main__":
    main()

2.4 数据加载

主要涉及了对读入图片的进行规格化,对图片大小进行统一(填充),例如图片大小统一为416×416像素,同时可以在数据加载函数中使用,多尺度训练、数据增强等策略,这部分可以自行设计,使用DataLoader进行数据装载,后面训练部分会详细介绍。

import random
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms
"""
作者:小小博
时间:2022.3.20
环境:夜深人静的晚上
本脚本有功能:
1.对应train_path.txt路径的数据进行YOLO需要的数据进行格式化
"""
# -------------------------------------------------------#
#   图像尺寸统一函数1
# -------------------------------------------------------#
def pad_to_square(img, pad_value):
    # -------------------------------------------------------#
    #    dim_diff = np.abs(h - w) 求出高和宽的绝对值差
    #    填补只有三种方式:
    #       1、h=w 不需要填
    #       2、h>w 需要对图像的宽左右进行填补
    #       3、h<w 需要对图像的高上下进行填补
    #    填补计算只会出现奇数和偶数,如果出现奇数就对上或者左多一像素进行填补
    #    pad1, pad2 是填充像素值
    #    pad = (左, 右, 上, 下)
    #    如果 w >= h 宽度大于高度 就对图片上下进行填充
    #    否则 对图像左右进行填充
    #    填充常量 constant 具体这个函数的内容可以自己去查一查
    # -------------------------------------------------------#
    c, h, w = img.shape
    dim_diff = np.abs(h - w)
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    img = F.pad(img, pad, "constant", value=pad_value)
    return img, pad

# -------------------------------------------------------#
#   图像尺寸统一函数2
#   把图像重新调整为需要的尺寸 对其进行采样
#   可以使用最邻近上采样、线性插值法、双线性插值法等
#   ’nearest’, ‘linear’, ‘bilinear’, ‘bicubic’ , ‘trilinear’和’area’. 默认使用’nearest’
# -------------------------------------------------------#

def img_resize(image, size):
    image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
    return image

# -------------------------------------------------------#
#   ListDataset 数据预处理类
# -------------------------------------------------------#
class ListDataset(Dataset):
    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        # -------------------------------------------------------#
        #   获取对应的 dataset\train\images\下面所有图片的绝对地址
        #   self.img_path
        #   D:\mypro\myYolo2\dataset\train\images\000001.jpg  图片的地址
        # -------------------------------------------------------#
        with open(list_path, "r") as file:
            self.img_path = file.readlines()
        # -------------------------------------------------------#
        #   获取对应的 dataset\train\labels\下面所有图片的b标签绝对地址
        #   self.label_path
        #   D:\mypro\myYolo2\dataset\train\labels\000001.txt  标签的地址
        #   path 是对应图片的绝地地址,把地址中的images替换为labels 再把图片.png或者.jpg格式换为.txt
        #   就得到了对应图片的标签绝绝对
        # -------------------------------------------------------#
        self.label_path = [path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt") for path in self.img_path ]

        self.img_size = img_size
        self.max_objects = 100
        self.augment = augment                      # 是否开启图像增强策略
        self.multiscale = multiscale                # 是否开启多尺度训练策略
        self.normalized_labels = normalized_labels  # 标签标准化
        self.min_size = self.img_size - 3 * 32      # 输入416 min_size = 320  多尺度训练时用
        self.max_size = self.img_size + 3 * 32      # 输入416 max_size = 512  多尺度训练时用
        self.batch_count = 0                        # 记录batch数 可以对多尺度训练进行设置

    # -------------------------------------------------------#
    #   __getitem__方法
    #   使用索引访问元素时
    #   如果对象为datas,data[key]取值,当实例对象做pdata[0] 运算时,会调用类中的方法__getitem_
    # -------------------------------------------------------#
    def __getitem__(self, index):

        # -------------------------------------------------------#
        #   获取你索引的图像绝对地址,取余操作的好处就是更大的容错
        #   rstrip()去掉末尾的空格
        # -------------------------------------------------------#
        img_path = self.img_path[index % len(self.img_path)].rstrip()
        # -------------------------------------------------------#
        #   Image 读出的图像数据格式是[h,w,c]要转为换为[c,h,w]tensor格式
        # -------------------------------------------------------#
        img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))
        # -------------------------------------------------------#
        #   img.shape = [3,416,416]    len(img.shape) = 3
        #   如果读入的图像是灰度图片只有一个通道,需要统一为三通道 img
        #   unsqueeze(0)升高维度 将[h,w] -> [1,h,w]
        #   expand(3,h,w) [1,h,w] -> 3,h,w]
        # -------------------------------------------------------#
        if len(img.shape) != 3:
            img = img.unsqueeze(0)
            img = img.expand(3, img.shape[1],img.shape[2])
        # -------------------------------------------------------#
        #    _, h, w 分别为图像的 通道数 高度 宽度
        #   如果 self.normalized_label = True  标签标准化
        #   h_factor, w_factor = (h, w)
        #   否则 h_factor, w_factor = (1, 1)
        # -------------------------------------------------------#
        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        # -------------------------------------------------------#
        #   图像填充,输入网络的图像数据必须满足 高和宽相等,如果不相等就需要对其短边进行填充
        #   pad_to_square 图像填充函数 详细注解请看函数
        #   voc 数据集里数据大多数的 要么长为500像素 或者宽为500像素
        #   最后都统一为500*500像素
        # -------------------------------------------------------#
        img, pad = pad_to_square(img, 0)
        _, padded_h, padded_w = img.shape  # padded_h = 500   padded_w = 500

        # -------------------------------------------------------#
        #   标签操作
        #   label_path 对应图像的标签绝对地址
        # -------------------------------------------------------#
        label_path = self.label_path[index % len(self.img_path)].rstrip()
        targets = None
        if os.path.exists(label_path):
            # -------------------------------------------------------#
            #   读入YOLO的标签格式 [class,x,y,w,h] 类别 中心坐标(x,y)和 宽高
            #   label_path 对应图像的标签绝对地址
            #   boxex.shape (class_num,5) class_num 为该张图像中标记的目标个数
            # -------------------------------------------------------#
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))

            # -------------------------------------------------------#
            #   将 (x,y,w,h) 格式转为 (x1,y1,x2,y2)
            #  (x1,y1) 目标左上角位置 (x2,y2)图片右下角位置 还原在原始图片上
            # -------------------------------------------------------#
            x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
            y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
            x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
            y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)

            # -------------------------------------------------------#
            #   原始标签是相对于原始图像的位置 原图尺寸为(500,336)
            #   现在填补的图 为(500,500) 所以目标相对位置需要进行改变
            #   pad = (左, 右, 上, 下)
            #   最简单的方法自己画图理解这部分 数形结合
            # -------------------------------------------------------#
            x1 += pad[0]
            y1 += pad[2]
            x2 += pad[0]
            y2 += pad[2]
            # x2 += pad[1]
            # y2 += pad[3]

            # -------------------------------------------------------#
            #   将 (x1,y1,x2,y2)转为 (x,y,w,h)
            #   (x1+x2)/2 得到中心点x坐标 再除填充后的图像宽度归一化
            #   boxes[:, 3] *= w_factor / padded_w 先乘以原始宽度恢复目标的宽度再除以填充后的宽度
            # -------------------------------------------------------#
            boxes[:, 1] = ((x1 + x2) / 2) / padded_w
            boxes[:, 2] = ((y1 + y2) / 2) / padded_h
            boxes[:, 3] *= w_factor / padded_w
            boxes[:, 4] *= h_factor / padded_h

            # -------------------------------------------------------#
            #  boxes.shape (class_num,5)目标数量和加上 5 = [种类,x,y,w,h]
            #  targets.shape (class_num,6) 6 = [batch_index,种类,x,y,w,h]
            #  后面我们对 DataLoader 时需要用第一个数来记录是第一个batch里面的图片
            # -------------------------------------------------------#
            targets = torch.zeros((len(boxes), 6))
            targets[:, 1:] = boxes

        # -------------------------------------------------------#
        #   图像增强策略,根据自己需求添加
        # -------------------------------------------------------#

        # if self.augment:
        #     if np.random.random() < 0.5:
        #         #随机水平翻转
        #

        return img_path, img, targets

    # -------------------------------------------------------#
    #    DataLoader的collate_fn 会把当前的批的内容传到collate_fn
    #    DataLoader(dataset,batch_size=32,shuffle=True,num_workers=0, pin_memory=True,collate_fn=dataset.collate_fn)
    #    最终返回是这个函数的返回值
    # -------------------------------------------------------#
    def collate_fn(self, batch):
        # -------------------------------------------------------#
        #    batch进行解包里面包含了 paths, imgs, targets
        # -------------------------------------------------------#
        paths, imgs, targets = list(zip(*batch))
        # -------------------------------------------------------#
        #   对标签进行非空判断,可能会存在一些图片没有标注的情况
        # -------------------------------------------------------#
        targets = [boxes for boxes in targets if boxes is not None]

        # -------------------------------------------------------#
        #   之前在第一个位置全部是填充0,现在有batch_size 就要为每个添加图片索引
        # -------------------------------------------------------#
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i

        # -------------------------------------------------------#
        #   每十个batch随机resize到不同尺寸 并且在区间[320,512]之间
        #   且必须满足32的倍数的尺寸
        #   主干网了Darknet没有全连接层,所以可以通过随机输入图像大小来增加模型泛化的能力
        # -------------------------------------------------------#
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # torch.stack:沿着一个新维度对输入张量序列进行连接
        imgs = torch.stack([img_resize(img, self.img_size) for img in imgs])
        self.batch_count += 1

        return paths, imgs, targets

    def __len__(self):
        # -------------------------------------------------------#
        # __len__()的作用是返回容器中元素的个数,可以自己设置是返回对象还是属性
        # 通过len()函数
        # -------------------------------------------------------#
        return len(self.img_path)
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐