完整的训练一个车牌识别模型
车牌识别是一个典型的计算机视觉多任务问题,通常包含车牌定位、字符分割和字符识别三个子任务。下面我将为你提供一个完整的车牌识别模型训练方案,使用 PyTorch 实现端到端的车牌识别。
完整的训练一个车牌识别模型
车牌识别是一个典型的计算机视觉多任务问题,通常包含车牌定位、字符分割和字符识别三个子任务。下面我将为你提供一个完整的车牌识别模型训练方案,使用 PyTorch 实现端到端的车牌识别
1. 问题分析
车牌识别需要解决的核心问题是:
从图像中识别出车牌上的所有字符
中国车牌通常包含 7 个字符(1 个省份简称 + 1 个字母 + 5 个字母或数字)
需要处理不同角度、光照条件和清晰度的车牌图像
2. 模型设计
本方案采用了一个基于 ResNet18 的多任务学习模型:
使用预训练的 ResNet18 作为特征提取器,提取车牌图像的高级特征
为车牌的每个字符位置设计一个分类器(共 7 个分类器)
每个分类器负责识别对应位置的字符
这种设计的优势是可以端到端训练,无需单独进行字符分割步骤。
3. 数据集准备
要训练这个模型,你需要一个包含车牌图像和对应标签的数据集:
图像应该是已经裁剪好的车牌区域(如果需要完整的端到端解决方案,还需要添加车牌检测模块)
每个图像的标签是对应的车牌字符(如 “京 A12345”)
数据集目录结构可以是简单的所有图像放在一个文件夹中,文件名包含车牌字符
你可以通过以下方式获取车牌数据集:
公开数据集如 CCPD(中国城市停车场数据集)
自己收集并标注图像
使用数据增强技术扩充数据集
4. 训练过程
训练过程包含以下关键步骤:
数据加载和预处理,包括 resize、旋转、平移等数据增强
模型初始化,使用预训练的 ResNet18 作为基础网络
多任务损失计算,对每个字符位置计算交叉熵损失并求和
模型优化,使用 Adam 优化器和学习率调度器
性能评估,包括字符级准确率和全牌准确率
5.主要代码
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.nn import functional as F
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import random
import string
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
# 1. 数据准备
# 定义车牌字符集(以中国车牌为例)
PROVINCES = ["京", "津", "沪", "渝", "冀", "豫", "云", "辽", "黑", "湘", "皖", "鲁", "新", "苏", "浙", "赣", "鄂", "桂", "甘", "晋", "蒙", "陕", "吉", "闽", "贵", "粤", "青", "藏", "川", "宁", "琼"]
ALPHABETS = list(string.ascii_uppercase)
DIGITS = list(string.digits)
CHARS = PROVINCES + ALPHABETS + DIGITS
CHAR_TO_IDX = {char: idx for idx, char in enumerate(CHARS)}
IDX_TO_CHAR = {idx: char for idx, char in enumerate(CHARS)}
NUM_CLASSES = len(CHARS)
PLATE_LENGTH = 7 # 中国车牌通常是7个字符
# 自定义车牌数据集
class LicensePlateDataset(Dataset):
def __init__(self, image_paths, labels, transform=None):
self.image_paths = image_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
label = self.labels[idx]
# 读取图像并转换为RGB
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 应用变换
if self.transform:
img = self.transform(img)
# 将标签转换为索引
label_indices = [CHAR_TO_IDX[char] for char in label]
return img, torch.tensor(label_indices, dtype=torch.long)
# 数据加载和预处理
def load_dataset(data_dir):
image_paths = []
labels = []
# 假设数据集组织结构为: data_dir/label/xxx.jpg
# 或者 data_dir/xxx_label.jpg
for filename in os.listdir(data_dir):
if filename.endswith(('.jpg', '.png', '.jpeg')):
# 提取标签(假设文件名格式为 "车牌字符.jpg")
label = filename.split('.')[0]
if len(label) == PLATE_LENGTH:
image_paths.append(os.path.join(data_dir, filename))
labels.append(label)
return image_paths, labels
# 定义数据变换
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((64, 192)), # 车牌图像通常是宽大于高
transforms.RandomRotation(degrees=5), # 随机旋转增强
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), # 随机平移增强
transforms.ColorJitter(brightness=0.2, contrast=0.2), # 亮度和对比度增强
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化参数
])
# 2. 模型定义
class LicensePlateRecognizer(nn.Module):
def __init__(self, num_classes=NUM_CLASSES, plate_length=PLATE_LENGTH):
super(LicensePlateRecognizer, self).__init__()
self.plate_length = plate_length
self.num_classes = num_classes
# 使用预训练的ResNet18作为特征提取器
self.backbone = models.resnet18(pretrained=True)
# 移除最后的全连接层
self.features = nn.Sequential(*list(self.backbone.children())[:-1])
# 计算特征图的维度
self.feature_dim = 512 # ResNet18最后输出的特征维度
# 为每个字符位置定义一个分类器
self.classifiers = nn.ModuleList([
nn.Linear(self.feature_dim, num_classes) for _ in range(plate_length)
])
def forward(self, x):
# 提取特征
x = self.features(x)
x = x.view(x.size(0), -1) # 展平特征
# 对每个字符位置进行预测
outputs = [classifier(x) for classifier in self.classifiers]
return outputs
# 3. 训练和验证函数
def train_epoch(model, train_loader, criterion, optimizer, device):
model.train()
total_loss = 0.0
correct = 0
total = 0
for batch_idx, (images, labels) in enumerate(train_loader):
images, labels = images.to(device), labels.to(device)
# 清零梯度
optimizer.zero_grad()
# 前向传播
outputs = model(images)
# 计算每个字符的损失并求和
loss = 0.0
for i in range(PLATE_LENGTH):
loss += criterion(outputs[i], labels[:, i])
# 反向传播和优化
loss.backward()
optimizer.step()
total_loss += loss.item()
# 计算准确率
for i in range(PLATE_LENGTH):
_, predicted = torch.max(outputs[i].data, 1)
total += labels.size(0)
correct += (predicted == labels[:, i]).sum().item()
# 打印训练信息
if batch_idx % 100 == 0:
print(f'Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
epoch_loss = total_loss / len(train_loader)
epoch_acc = 100 * correct / total
return epoch_loss, epoch_acc
def validate(model, val_loader, criterion, device):
model.eval()
total_loss = 0.0
correct = 0
total = 0
full_correct = 0 # 全对的车牌数量
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
# 计算损失
loss = 0.0
for i in range(PLATE_LENGTH):
loss += criterion(outputs[i], labels[:, i])
total_loss += loss.item()
# 计算准确率
batch_full_correct = True
for i in range(PLATE_LENGTH):
_, predicted = torch.max(outputs[i].data, 1)
total += labels.size(0)
correct += (predicted == labels[:, i]).sum().item()
# 检查是否所有字符都预测正确
if not batch_full_correct:
continue
batch_full_correct = batch_full_correct and (predicted == labels[:, i]).all().item()
if batch_full_correct:
full_correct += labels.size(0)
val_loss = total_loss / len(val_loader)
val_acc = 100 * correct / total
val_full_acc = 100 * full_correct / len(val_loader.dataset)
return val_loss, val_acc, val_full_acc
# 4. 预测函数
def predict_plate(model, image, transform, device):
model.eval()
# 预处理图像
if transform:
image = transform(image)
# 添加批次维度
image = image.unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(image)
# 解析预测结果
plate = []
for i in range(PLATE_LENGTH):
_, predicted = torch.max(outputs[i].data, 1)
plate.append(IDX_TO_CHAR[predicted.item()])
return ''.join(plate)
# 5. 主函数
def main():
# 配置参数
data_dir = './license_plates' # 数据集目录
batch_size = 32
epochs = 20
lr = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 加载数据
print("加载数据集...")
image_paths, labels = load_dataset(data_dir)
# 划分训练集和验证集
train_paths, val_paths, train_labels, val_labels = train_test_split(
image_paths, labels, test_size=0.2, random_state=42
)
# 创建数据集和数据加载器
train_dataset = LicensePlateDataset(train_paths, train_labels, transform=transform)
val_dataset = LicensePlateDataset(val_paths, val_labels, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
print(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
# 初始化模型、损失函数和优化器
model = LicensePlateRecognizer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
# 记录训练过程
train_losses = []
train_accs = []
val_losses = []
val_accs = []
val_full_accs = []
# 训练模型
print("开始训练...")
best_val_acc = 0.0
for epoch in range(epochs):
print(f"\nEpoch {epoch+1}/{epochs}")
print("-" * 50)
# 训练
train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
train_losses.append(train_loss)
train_accs.append(train_acc)
print(f"训练集: 损失 = {train_loss:.4f}, 字符准确率 = {train_acc:.2f}%")
# 验证
val_loss, val_acc, val_full_acc = validate(model, val_loader, criterion, device)
val_losses.append(val_loss)
val_accs.append(val_acc)
val_full_accs.append(val_full_acc)
print(f"验证集: 损失 = {val_loss:.4f}, 字符准确率 = {val_acc:.2f}%, 全牌准确率 = {val_full_acc:.2f}%")
# 调整学习率
scheduler.step(val_loss)
# 保存最佳模型
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), 'best_license_plate_model.pth')
print(f"保存最佳模型 (验证字符准确率: {best_val_acc:.2f}%)")
# 绘制训练曲线
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='训练损失')
plt.plot(val_losses, label='验证损失')
plt.title('损失曲线')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='训练字符准确率')
plt.plot(val_accs, label='验证字符准确率')
plt.plot(val_full_accs, label='验证全牌准确率')
plt.title('准确率曲线')
plt.xlabel('Epoch')
plt.ylabel('准确率 (%)')
plt.legend()
plt.tight_layout()
plt.savefig('training_curves.png')
plt.show()
# 随机展示一些预测结果
model.load_state_dict(torch.load('best_license_plate_model.pth'))
model.to(device)
sample_indices = random.sample(range(len(val_dataset)), 5)
plt.figure(figsize=(15, 3))
for i, idx in enumerate(sample_indices):
image, label = val_dataset[idx]
# 转换回原始图像用于显示
img_np = image.permute(1, 2, 0).numpy()
img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
img_np = np.clip(img_np, 0, 1)
# 预测车牌
predicted_plate = predict_plate(model, image, None, device)
actual_plate = ''.join([IDX_TO_CHAR[idx.item()] for idx in label])
plt.subplot(1, 5, i+1)
plt.imshow(img_np)
plt.title(f'预测: {predicted_plate}\n实际: {actual_plate}')
plt.axis('off')
plt.tight_layout()
plt.savefig('prediction_samples.png')
plt.show()
if __name__ == '__main__':
main()
6. 模型改进方向
你可以通过以下方式进一步提高模型性能:
使用更强大的骨干网络(如 ResNet50、MobileNetV2 等)
添加注意力机制,让模型更关注字符区域
引入 CTC(Connectionist Temporal Classification)损失函数处理变长序列
增加车牌检测模块,实现从原始图像到车牌识别的完整流程
使用更丰富的数据增强技术,提高模型的鲁棒性
这个实现提供了一个基础框架,你可以根据实际需求进行调整和扩展。
更多推荐
所有评论(0)