别再对着Market-1501数据集发懵了!手把手教你用Python脚本搞定PyTorch格式转换
·
从零开始:Market-1501数据集高效预处理实战指南
第一次打开Market-1501数据集时,那种面对复杂目录结构的茫然感我至今记忆犹新。作为行人重识别领域的经典基准数据集,它独特的文件组织和命名规则常常让初学者望而生畏。本文将带你用Python脚本一步步拆解这个"硬骨头",不仅教你如何操作,更让你理解每个步骤背后的设计逻辑。
1. 理解Market-1501的数据哲学
在动手写代码前,我们需要先读懂数据集设计者的意图。Market-1501的目录结构看似复杂,实则暗含行人重识别任务的评估逻辑:
- 多摄像头采集 :6个摄像头(5高清+1低清)的交叉视角数据,模拟真实监控场景
- 训练/测试划分 :751人用于训练,750人用于测试,确保模型评估的公正性
- 双重标注体系 :自动检测框(DPM)与人工标注并存,反映实际应用中的噪声情况
数据集的核心目录包括:
Market-1501/
├── bounding_box_test/ # 测试集图像
├── bounding_box_train/ # 训练集图像
├── query/ # 查询图像
├── gt_bbox/ # 手工标注框
└── gt_query/ # 查询结果评估标注
文件命名规则包含丰富信息,以 0017_c2s1_000976_01.jpg 为例:
0017:行人ID(标签)c2:摄像头2拍摄s1:第1段视频序列000976:视频帧编号01:该帧上第1个检测框(00表示人工标注)
2. PyTorch数据加载的格式需求
PyTorch的 DataLoader 期望数据组织为:
pytorch/
├── train/
│ ├── 0001/ # 每个ID单独文件夹
│ │ ├── 0001_c1s1_000451_03.jpg
│ │ └── ...
├── val/
└── query/
这种结构与原始格式的关键差异在于:
- 按ID聚合 :同一ID的所有图像归入同一文件夹
- 分离查询集 :评估时需单独处理的查询图像
- 训练验证拆分 :防止模型过拟合的必要措施
3. 自动化转换脚本详解
以下脚本将原始格式转换为PyTorch友好格式,我们逐模块分析其实现逻辑:
3.1 基础目录准备
import os
from shutil import copyfile
dataset_path = './Market-1501'
pytorch_path = os.path.join(dataset_path, 'pytorch')
# 创建输出目录结构
os.makedirs(pytorch_path, exist_ok=True)
required_folders = ['train', 'val', 'query', 'gallery']
for folder in required_folders:
os.makedirs(os.path.join(pytorch_path, folder), exist_ok=True)
注意:
exist_ok=True参数避免目录已存在时报错,使脚本具备幂等性
3.2 训练集处理与验证集拆分
def process_train_val(src_dir, train_dir, val_dir, val_samples=1):
for img_name in os.listdir(src_dir):
if not img_name.endswith('.jpg'):
continue
person_id = img_name.split('_')[0]
src_path = os.path.join(src_dir, img_name)
# 为每个ID创建专属目录
person_train_dir = os.path.join(train_dir, person_id)
os.makedirs(person_train_dir, exist_ok=True)
# 复制到训练集
copyfile(src_path, os.path.join(person_train_dir, img_name))
# 验证集处理(每个ID取前N张)
if len(os.listdir(val_dir)) < val_samples:
person_val_dir = os.path.join(val_dir, person_id)
os.makedirs(person_val_dir, exist_ok=True)
copyfile(src_path, os.path.join(person_val_dir, img_name))
process_train_val(
src_dir=os.path.join(dataset_path, 'bounding_box_train'),
train_dir=os.path.join(pytorch_path, 'train'),
val_dir=os.path.join(pytorch_path, 'val'),
val_samples=1
)
3.3 查询集与测试集处理
def process_query_gallery(src_dir, dst_dir):
for img_name in os.listdir(src_dir):
if not img_name.endswith('.jpg'):
continue
person_id = img_name.split('_')[0]
dst_person_dir = os.path.join(dst_dir, person_id)
os.makedirs(dst_person_dir, exist_ok=True)
copyfile(
os.path.join(src_dir, img_name),
os.path.join(dst_person_dir, img_name)
)
# 处理查询集
process_query_gallery(
src_dir=os.path.join(dataset_path, 'query'),
dst_dir=os.path.join(pytorch_path, 'query')
)
# 处理测试集(gallery)
process_query_gallery(
src_dir=os.path.join(dataset_path, 'bounding_box_test'),
dst_dir=os.path.join(pytorch_path, 'gallery')
)
4. 构建PyTorch数据加载管道
转换完成后,我们可以用以下方式创建高效的数据管道:
4.1 自定义Dataset类
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
class Market1501Dataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root = root_dir
self.transform = transform or T.Compose([
T.Resize((256, 128)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
self.samples = []
for person_id in os.listdir(root_dir):
person_dir = os.path.join(root_dir, person_id)
for img_name in os.listdir(person_dir):
self.samples.append((
os.path.join(person_dir, img_name),
int(person_id) # 将ID转为整数
))
def __getitem__(self, index):
img_path, label = self.samples[index]
img = Image.open(img_path).convert('RGB')
return self.transform(img), label
def __len__(self):
return len(self.samples)
4.2 创建DataLoader实例
from torch.utils.data import DataLoader
# 数据增强配置
train_transform = T.Compose([
T.RandomHorizontalFlip(),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0),
T.RandomRotation(10),
T.Resize((256, 128)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 初始化数据集
train_set = Market1501Dataset(
root_dir=os.path.join(pytorch_path, 'train'),
transform=train_transform
)
val_set = Market1501Dataset(
root_dir=os.path.join(pytorch_path, 'val')
)
# 创建数据加载器
train_loader = DataLoader(
train_set, batch_size=32, shuffle=True, num_workers=4
)
val_loader = DataLoader(
val_set, batch_size=32, shuffle=False, num_workers=4
)
5. 高级技巧与避坑指南
在实际项目中,我们还需要考虑以下关键点:
5.1 数据分布分析
使用以下代码检查数据分布:
import matplotlib.pyplot as plt
def plot_distribution(dataset):
id_counts = {}
for _, label in dataset.samples:
id_counts[label] = id_counts.get(label, 0) + 1
plt.figure(figsize=(10, 5))
plt.bar(id_counts.keys(), id_counts.values())
plt.xlabel('Person ID')
plt.ylabel('Image Count')
plt.title('Images per ID Distribution')
plt.show()
plot_distribution(train_set)
5.2 跨摄像头验证策略
Market-1501的特殊性在于同一ID可能来自不同摄像头,因此验证集应该确保:
- 至少包含来自两个摄像头的样本
- 避免同一摄像头下的图像既出现在训练集又出现在验证集
改进后的验证集采样逻辑:
def create_cross_camera_val_set(train_dir, val_dir, min_cameras=2):
from collections import defaultdict
import random
# 统计每个ID的摄像头分布
id_cameras = defaultdict(set)
for img_name in os.listdir(train_dir):
if not img_name.endswith('.jpg'):
continue
person_id, camera_id = img_name.split('_')[0], img_name.split('_')[1][1]
id_cameras[person_id].add(camera_id)
# 为每个ID选择验证样本
for person_id, cameras in id_cameras.items():
if len(cameras) < min_cameras:
continue
# 随机选择两个不同摄像头的样本
selected_cams = random.sample(list(cameras), 2)
val_samples = [
img for img in os.listdir(os.path.join(train_dir, person_id))
if img.split('_')[1][1] in selected_cams
][:2]
# 移动到验证集
os.makedirs(os.path.join(val_dir, person_id), exist_ok=True)
for img in val_samples:
src = os.path.join(train_dir, person_id, img)
dst = os.path.join(val_dir, person_id, img)
os.rename(src, dst)
5.3 高效数据加载优化
对于大规模数据集,建议使用:
- LMDB数据库 :将图像转为二进制存储,减少IO开销
- DALI加速 :NVIDIA开发的高性能数据加载库
- 预处理缓存 :对固定变换的结果进行磁盘缓存
示例LMDB存储实现:
import lmdb
import pickle
def convert_to_lmdb(dataset, lmdb_path, write_frequency=1000):
env = lmdb.open(lmdb_path, map_size=1099511627776)
with env.begin(write=True) as txn:
for idx, (img_path, label) in enumerate(dataset.samples):
with open(img_path, 'rb') as f:
img_data = f.read()
# 使用pickle序列化存储
data = pickle.dumps((img_data, label))
txn.put(str(idx).encode(), data)
if idx % write_frequency == 0:
print(f"Processed {idx} samples")
env.close()
更多推荐
所有评论(0)