在深度学习中,图像分类已然是一种初级任务。在计算机视觉领域,使用深度学习对图像进行高维特征提取,实现对图像的分类已经非常成熟。常见的深度学习模型,如lenet-5、AlexNet、VGG、Inception、Resnet等系列模型都可以实现图像分类任务。

        一般我们使用Tensorflow或者Pytorch构建深度学习模型,并使用cifar10、cifar100、Imagenet等数据集即可进行一个图像分类模型的训练。要解决实际问题,实现算法落地,就需要根据应用场景对我们需要分类的数据进行数据收集并构建数据集,以完成针对特定分类任务的模型训练和测试。

        例如,我们对几种小鸟(斑鸠、麻雀、白鹭)进行分类。首先需要收集这几种鸟类的图片,然后人为的对这些图像进分类,可以使用不同的文件名对图像进区分,最简单的办法就是将不同类别归入不同文件夹,例如建立一个麻雀文件夹存入所有麻雀相关图片。这样一个小鸟分类的数据就是如下形式:

麻雀数据集:

白鹭数据集:

 

斑鸠数据集:

 

        对数据集划分训练集和验证集,可以用相同目录结构来划分,也可以在训练时通过随机读取方式将一部分图像作为训练集一部分作为验证集。另外数据集的制作对后期训练结果也会产生影响,首先保证不同类别的数据集图片数量尽量一致,为了保证训练数据覆盖足够多特征和场景,也需要确保每个分类图像数量在一百张以上。

        完成数据的收集和整理后,需要对数据集进行预处理。由于图像数据集一般比较大,训练时采用batch方式对图像进行加载,即从成千上万张照片中抽取一部分加载进内存进行训练。另外,还需要对不同分类进行编码也就是制作标签,比如'斑鸠': 0, '白鹭': 1, '麻雀': 2。这需要提前将不同数据生成对应的读取路径和对应编码,存入一个表格文件以方便训练时加载。

        这里使用python对数据集进行读取和预处理。

预处理脚本:

 

import os, glob
import random, csv


def load_csv(root, filename, name2label):
    # root:数据集根目录
    # filename:csv文件名
    # name2label:类别名编码表
    if not os.path.exists(os.path.join(root, filename)):
        images = []
        for name in name2label.keys():
            images += glob.glob(os.path.join(root, name, '*.png'))
            images += glob.glob(os.path.join(root, name, '*.jpg'))
            images += glob.glob(os.path.join(root, name, '*.jpeg'))

        print(len(images), images)

        random.shuffle(images)
        with open(os.path.join(root, filename), mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:
                name = img.split(os.sep)[-2]
                label = name2label[name]
                writer.writerow([img, label])
            print('written into csv file:', filename)

    # read from csv file
    images, labels = [], []
    with open(os.path.join(root, filename)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, label = row
            label = int(label)

            images.append(img)
            labels.append(label)

    assert len(images) == len(labels)

    return images, labels


def load_birds(root, mode='train'):
    # 创建数字编码表
    name2label = {}  # "sq...":0
    for name in sorted(os.listdir(os.path.join(root))):
        if not os.path.isdir(os.path.join(root, name)):
            continue
        # 给每个类别编码一个数字
        name2label[name] = len(name2label.keys())

    # 读取Label信息
    # [file1,file2,], [3,1]
    images, labels = load_csv(root, 'images.csv', name2label)

    if mode == 'train':  # 60%
        images = images[:int(0.6 * len(images))]
        labels = labels[:int(0.6 * len(labels))]
    elif mode == 'val':  # 20% = 60%->80%
        images = images[int(0.6 * len(images)):int(0.8 * len(images))]
        labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
    else:  # 20% = 80%->100%
        images = images[int(0.8 * len(images)):]
        labels = labels[int(0.8 * len(labels)):]

    return images, labels, name2label


def main():
    import time
    images, labels, table = load_birds('datasets', 'train')
    print('images', len(images), images)
    print('labels', len(labels), labels)
    print(table)


if __name__ == '__main__':
    main()

生成数据集表格:

 

 

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐