数据准备

数据加载器

1. 构建数据类

在深度学习中,数据加载器(DataLoader)负责高效地加载和批量化数据,以便可以快速输入到模型中进行训练。PyTorch 提供了 DatasetDataLoader 类来简化这一过程。

1.1 Dataset类

Dataset 类是一个抽象类,它允许你自定义如何从数据源中加载数据。通常,你需要继承 Dataset 类并实现以下两个方法:

  • __len__():返回数据集的大小。
  • __getitem__():返回给定索引的数据样本。
  • 在Pytorch中,构建自定义数据加载类通常需要继承torch.utils.data.Dataset并实现以下几个方法:
  1. _init_ 方法
    用于初始化数据集对象:通常在这里加载数据,或者定义如何从存储中获取数据的路径和方法。

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
    
  2. _len_ 方法
    返回样本数量:需要实现,以便 Dataloader加载器能够知道数据集的大小。

    def __len__(self):
        return len(self.data)
    
  3. _getitem_ 方法
    根据索引返回样本:将从数据集中提取一个样本,并可能对样本进行预处理或变换。

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.labels[index]
        return sample, label
    

​ 如果你需要进行更多的预处理或数据变换,可以在 _getitem_ 方法中添加额外的逻辑。

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


# 定义数据加载类
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        """
        初始化数据集
        :data: 样本数据(例如,一个 NumPy 数组或 PyTorch 张量)
        :labels: 样本标签
        """
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        index = min(max(index, 0), len(self.data) - 1)
        sample = self.data[index]
        label = self.labels[index]
        return sample, label


def test001():
    # 简单的数据集准备
    data_x = torch.randn(666, 20, requires_grad=True, dtype=torch.float32)
    data_y = torch.randn(data_x.shape[0], 1, dtype=torch.float32)
    dataset = CustomDataset(data_x, data_y)
    # 随便打印个数据看一下
    print(dataset[0])


if __name__ == "__main__":
    test001()

1.2 TensorDataset类

TensorDataset 类是 Dataset 的一个简单实现,它将张量打包为元组。对于许多常见的任务(如监督学习),我们可以直接使用 TensorDataset

from torch.utils.data import TensorDataset

data = torch.randn(100, 3)  # 100个样本,3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签

dataset = TensorDataset(data, labels)
为什么要使用这些类?

通过继承或使用 Dataset 类,你可以灵活地处理数据加载的细节,如数据预处理、标签处理、数据增强等。这种方式使得代码更加模块化和可复用。

2. 数据加载器(DataLoader)

DataLoader 是用来批量加载数据的工具,它会自动地将数据分成小批次并进行洗牌。你可以根据需要指定批量大小、是否打乱数据等。

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
为什么使用数据加载器?
  • 批处理:通过批量化加载数据,减少内存消耗。
  • 多线程加载:通过指定 num_workers 参数,数据可以并行加载,减少I/O瓶颈,提高训练效率。
  • 数据洗牌:在每个 epoch 之前洗牌数据,避免模型过拟合。

数据集加载案例

1. 加载 CSV 数据集

CSV 格式的文件通常包含表格数据,每一行表示一个样本,列表示不同的特征。在 PyTorch 中,我们可以使用 pandas 库加载 CSV 数据并将其转换为 PyTorch 的张量。

import pandas as pd
import torch
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    def __init__(self, file_path):
        self.data = pd.read_csv(file_path)
        self.features = torch.tensor(self.data.iloc[:, :-1].values, dtype=torch.float32)
        self.labels = torch.tensor(self.data.iloc[:, -1].values, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
为什么这样做?

CSV 文件格式简单且广泛使用,通常用于存储结构化数据。通过将其加载到 PyTorch 的张量中,能够直接利用 GPU 进行加速处理。

2. 加载图片数据集

对于图像数据集,通常需要加载和预处理图像。这可以通过 torchvision 库来实现,它提供了 ImageFolder 等工具,可以方便地从文件夹中加载图像数据。

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder('path/to/images', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
为什么要这么做?

图像数据通常需要进行尺寸调整、归一化等预处理,torchvision.transforms 提供了多种常用的图像变换操作,可以方便地应用到数据加载过程中。

3. 加载官方数据集

PyTorch 提供了许多常见的标准数据集,可以通过 torchvision.datasets 轻松加载,如 MNIST、CIFAR-10 等。加载这些数据集非常简单,可以直接使用预定义的 API。

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
为什么使用官方数据集?

官方数据集已被预处理,并且通常用于标准化的训练和评估任务,使用它们可以避免处理数据集的繁琐步骤,直接进行模型训练和测试。

数据探索与清洗

1. 检查图像与标注匹配

图像和标注匹配是确保模型训练正确性的第一步。在某些任务中,图像和其对应的标注可能会不匹配。通常需要检查每个图像是否有有效的标注。

import cv2
import os

image_path = 'path/to/image'
annotation_path = 'path/to/annotation'

# 检查图像文件是否有效
image = cv2.imread(image_path)
if image is None:
    print(f"Invalid image file: {image_path}")
    
# 检查标注文件是否有效
with open(annotation_path) as f:
    annotations = f.readlines()
    if not annotations:
        print(f"Invalid annotation file: {annotation_path}")
为什么这样做?

确保图像和标注文件匹配是保证模型训练质量的关键。如果图像和标注不匹配,会导致模型学习到错误的特征,影响训练效果。

2. 删除损坏图像

图像数据集可能包含损坏的图像(如无法打开的文件)。在训练模型之前,需要清理这些损坏的图像。

def check_image_validity(image_path):
    try:
        img = cv2.imread(image_path)
        if img is None:
            return False
        return True
    except Exception as e:
        return False

# 遍历数据集,删除损坏图像
image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg', 'path/to/image3.jpg']
valid_images = [img for img in image_paths if check_image_validity(img)]
为什么这么做?

删除损坏图像有助于确保训练数据集的质量,避免模型因损坏的图像而学到不必要的噪音。

3. 数据集划分

数据集通常需要划分为训练集、验证集和测试集。常见的做法是将 70%-80% 的数据用作训练,剩下的用于验证和测试。可以使用 train_test_split 来完成这一过程。

from sklearn.model_selection import train_test_split

data = [i for i in range(100)]  # 假设有100个样本
train_data, test_data = train_test_split(data, test_size=0.2)
为什么划分数据集?

数据集划分有助于模型评估。训练集用于训练模型,验证集用于调参,测试集用于评估模型的最终性能。


结论

数据准备是深度学习项目中至关重要的一步。通过合理使用数据加载器和进行必要的数据清洗,可以确保模型能够高效且准确地进行训练。特别是图像和文本等复杂数据类型,适当的预处理和清洗能够极大地提高模型的性能和泛化能力。在进行数据处理时,确保数据完整性和一致性是非常关键的,避免无效数据干扰训练过程。

Logo

纵情码海钱塘涌,杭州开发者创新动! 属于杭州的开发者社区!致力于为杭州地区的开发者提供学习、合作和成长的机会;同时也为企业交流招聘提供舞台!

更多推荐