完整的训练一个车牌识别模型

车牌识别是一个典型的计算机视觉多任务问题,通常包含车牌定位、字符分割和字符识别三个子任务。下面我将为你提供一个完整的车牌识别模型训练方案,使用 PyTorch 实现端到端的车牌识别

1. 问题分析

车牌识别需要解决的核心问题是:
从图像中识别出车牌上的所有字符
中国车牌通常包含 7 个字符(1 个省份简称 + 1 个字母 + 5 个字母或数字)
需要处理不同角度、光照条件和清晰度的车牌图像

2. 模型设计

本方案采用了一个基于 ResNet18 的多任务学习模型:
使用预训练的 ResNet18 作为特征提取器,提取车牌图像的高级特征
为车牌的每个字符位置设计一个分类器(共 7 个分类器)
每个分类器负责识别对应位置的字符
这种设计的优势是可以端到端训练,无需单独进行字符分割步骤。

3. 数据集准备

要训练这个模型,你需要一个包含车牌图像和对应标签的数据集:
图像应该是已经裁剪好的车牌区域(如果需要完整的端到端解决方案,还需要添加车牌检测模块)
每个图像的标签是对应的车牌字符(如 “京 A12345”)
数据集目录结构可以是简单的所有图像放在一个文件夹中,文件名包含车牌字符
你可以通过以下方式获取车牌数据集:
公开数据集如 CCPD(中国城市停车场数据集)
自己收集并标注图像
使用数据增强技术扩充数据集

4. 训练过程

训练过程包含以下关键步骤:
数据加载和预处理,包括 resize、旋转、平移等数据增强
模型初始化,使用预训练的 ResNet18 作为基础网络
多任务损失计算,对每个字符位置计算交叉熵损失并求和
模型优化,使用 Adam 优化器和学习率调度器
性能评估,包括字符级准确率和全牌准确率

5.主要代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.nn import functional as F
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import random
import string

# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# 1. 数据准备
# 定义车牌字符集(以中国车牌为例)
PROVINCES = ["京", "津", "沪", "渝", "冀", "豫", "云", "辽", "黑", "湘", "皖", "鲁", "新", "苏", "浙", "赣", "鄂", "桂", "甘", "晋", "蒙", "陕", "吉", "闽", "贵", "粤", "青", "藏", "川", "宁", "琼"]
ALPHABETS = list(string.ascii_uppercase)
DIGITS = list(string.digits)
CHARS = PROVINCES + ALPHABETS + DIGITS
CHAR_TO_IDX = {char: idx for idx, char in enumerate(CHARS)}
IDX_TO_CHAR = {idx: char for idx, char in enumerate(CHARS)}
NUM_CLASSES = len(CHARS)
PLATE_LENGTH = 7  # 中国车牌通常是7个字符

# 自定义车牌数据集
class LicensePlateDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # 读取图像并转换为RGB
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # 应用变换
        if self.transform:
            img = self.transform(img)
            
        # 将标签转换为索引
        label_indices = [CHAR_TO_IDX[char] for char in label]
        return img, torch.tensor(label_indices, dtype=torch.long)

# 数据加载和预处理
def load_dataset(data_dir):
    image_paths = []
    labels = []
    
    # 假设数据集组织结构为: data_dir/label/xxx.jpg
    # 或者 data_dir/xxx_label.jpg
    for filename in os.listdir(data_dir):
        if filename.endswith(('.jpg', '.png', '.jpeg')):
            # 提取标签(假设文件名格式为 "车牌字符.jpg")
            label = filename.split('.')[0]
            if len(label) == PLATE_LENGTH:
                image_paths.append(os.path.join(data_dir, filename))
                labels.append(label)
    
    return image_paths, labels

# 定义数据变换
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((64, 192)),  # 车牌图像通常是宽大于高
    transforms.RandomRotation(degrees=5),  # 随机旋转增强
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # 随机平移增强
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # 亮度和对比度增强
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化参数
])

# 2. 模型定义
class LicensePlateRecognizer(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, plate_length=PLATE_LENGTH):
        super(LicensePlateRecognizer, self).__init__()
        self.plate_length = plate_length
        self.num_classes = num_classes
        
        # 使用预训练的ResNet18作为特征提取器
        self.backbone = models.resnet18(pretrained=True)
        # 移除最后的全连接层
        self.features = nn.Sequential(*list(self.backbone.children())[:-1])
        
        # 计算特征图的维度
        self.feature_dim = 512  # ResNet18最后输出的特征维度
        
        # 为每个字符位置定义一个分类器
        self.classifiers = nn.ModuleList([
            nn.Linear(self.feature_dim, num_classes) for _ in range(plate_length)
        ])
    
    def forward(self, x):
        # 提取特征
        x = self.features(x)
        x = x.view(x.size(0), -1)  # 展平特征
        
        # 对每个字符位置进行预测
        outputs = [classifier(x) for classifier in self.classifiers]
        return outputs

# 3. 训练和验证函数
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        # 清零梯度
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(images)
        
        # 计算每个字符的损失并求和
        loss = 0.0
        for i in range(PLATE_LENGTH):
            loss += criterion(outputs[i], labels[:, i])
        
        # 反向传播和优化
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # 计算准确率
        for i in range(PLATE_LENGTH):
            _, predicted = torch.max(outputs[i].data, 1)
            total += labels.size(0)
            correct += (predicted == labels[:, i]).sum().item()
        
        # 打印训练信息
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
    
    epoch_loss = total_loss / len(train_loader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    full_correct = 0  # 全对的车牌数量
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            
            # 计算损失
            loss = 0.0
            for i in range(PLATE_LENGTH):
                loss += criterion(outputs[i], labels[:, i])
            total_loss += loss.item()
            
            # 计算准确率
            batch_full_correct = True
            for i in range(PLATE_LENGTH):
                _, predicted = torch.max(outputs[i].data, 1)
                total += labels.size(0)
                correct += (predicted == labels[:, i]).sum().item()
                
                # 检查是否所有字符都预测正确
                if not batch_full_correct:
                    continue
                batch_full_correct = batch_full_correct and (predicted == labels[:, i]).all().item()
            
            if batch_full_correct:
                full_correct += labels.size(0)
    
    val_loss = total_loss / len(val_loader)
    val_acc = 100 * correct / total
    val_full_acc = 100 * full_correct / len(val_loader.dataset)
    return val_loss, val_acc, val_full_acc

# 4. 预测函数
def predict_plate(model, image, transform, device):
    model.eval()
    
    # 预处理图像
    if transform:
        image = transform(image)
    
    # 添加批次维度
    image = image.unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(image)
    
    # 解析预测结果
    plate = []
    for i in range(PLATE_LENGTH):
        _, predicted = torch.max(outputs[i].data, 1)
        plate.append(IDX_TO_CHAR[predicted.item()])
    
    return ''.join(plate)

# 5. 主函数
def main():
    # 配置参数
    data_dir = './license_plates'  # 数据集目录
    batch_size = 32
    epochs = 20
    lr = 0.001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    
    # 加载数据
    print("加载数据集...")
    image_paths, labels = load_dataset(data_dir)
    
    # 划分训练集和验证集
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42
    )
    
    # 创建数据集和数据加载器
    train_dataset = LicensePlateDataset(train_paths, train_labels, transform=transform)
    val_dataset = LicensePlateDataset(val_paths, val_labels, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    
    print(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
    
    # 初始化模型、损失函数和优化器
    model = LicensePlateRecognizer().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
    
    # 记录训练过程
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    val_full_accs = []
    
    # 训练模型
    print("开始训练...")
    best_val_acc = 0.0
    
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print("-" * 50)
        
        # 训练
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        print(f"训练集: 损失 = {train_loss:.4f}, 字符准确率 = {train_acc:.2f}%")
        
        # 验证
        val_loss, val_acc, val_full_acc = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_full_accs.append(val_full_acc)
        print(f"验证集: 损失 = {val_loss:.4f}, 字符准确率 = {val_acc:.2f}%, 全牌准确率 = {val_full_acc:.2f}%")
        
        # 调整学习率
        scheduler.step(val_loss)
        
        # 保存最佳模型
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_license_plate_model.pth')
            print(f"保存最佳模型 (验证字符准确率: {best_val_acc:.2f}%)")
    
    # 绘制训练曲线
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='训练损失')
    plt.plot(val_losses, label='验证损失')
    plt.title('损失曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='训练字符准确率')
    plt.plot(val_accs, label='验证字符准确率')
    plt.plot(val_full_accs, label='验证全牌准确率')
    plt.title('准确率曲线')
    plt.xlabel('Epoch')
    plt.ylabel('准确率 (%)')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.show()
    
    # 随机展示一些预测结果
    model.load_state_dict(torch.load('best_license_plate_model.pth'))
    model.to(device)
    
    sample_indices = random.sample(range(len(val_dataset)), 5)
    plt.figure(figsize=(15, 3))
    
    for i, idx in enumerate(sample_indices):
        image, label = val_dataset[idx]
        # 转换回原始图像用于显示
        img_np = image.permute(1, 2, 0).numpy()
        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)
        
        # 预测车牌
        predicted_plate = predict_plate(model, image, None, device)
        actual_plate = ''.join([IDX_TO_CHAR[idx.item()] for idx in label])
        
        plt.subplot(1, 5, i+1)
        plt.imshow(img_np)
        plt.title(f'预测: {predicted_plate}\n实际: {actual_plate}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('prediction_samples.png')
    plt.show()

if __name__ == '__main__':
    main()

6. 模型改进方向

你可以通过以下方式进一步提高模型性能:
使用更强大的骨干网络(如 ResNet50、MobileNetV2 等)
添加注意力机制,让模型更关注字符区域
引入 CTC(Connectionist Temporal Classification)损失函数处理变长序列
增加车牌检测模块,实现从原始图像到车牌识别的完整流程
使用更丰富的数据增强技术,提高模型的鲁棒性
这个实现提供了一个基础框架,你可以根据实际需求进行调整和扩展。

Logo

惟楚有才,于斯为盛。欢迎来到长沙!!! 茶颜悦色、臭豆腐、CSDN和你一个都不能少~

更多推荐