一、概述

在前序系列博客中,系统已构建了一个分工明确、功能互补的双轨制处理流程。针对国外证件,后端服务采用基于预训练MobileNetV3模型的特征提取方案,通过计算待查证件与数据库样证模板特征向量的余弦相似度,实现证件类型的快速匹配。针对国内证件,则启用一个基于YOLOv11m的高精度紫外防伪特征检测模型,专注于真伪鉴别。这一架构虽然逻辑清晰、覆盖面广,但在国外证件识别的实践中,其核心环节——特征提取的有效性,尚有提升空间。

当前系统的模板匹配性能,高度依赖于MobileNetV3模型所提取特征向量的区分度。然而,该模型是在通用、大规模的ImageNet数据集上预训练的,其学习到的特征更善于分辨猫、狗、汽车等宏观物体,而非国外证件这类高度细分的、以版式布局、微缩文字、安全线和底纹等精细特征为主要区分依据的特殊图像。将这种通用特征直接应用于证件匹配,其效果虽可接受,但并未达到最优,相似度计算结果有时难以在版式极为接近的不同证件间形成足够清晰的界限,构成了系统准确性的瓶颈。

要提升模型针对性,最直接的思路是利用现有的样证库对其进行微调。然而,此路径面临一个核心挑战:小样本问题。在已建立的samples样证库中,每个具体的证件类别(例如,某个国家某个州在特定年份发布的驾照版本)通常只包含极少数样本,大部分类别仅有1至3个实例,部分甚至为孤例。对于深度神经网络而言,在如此稀疏的数据上进行传统的监督式微调,模型极易陷入过拟合,即死记硬背下有限的训练样本,而丧失对新样本的泛化能力。

为攻克这一难题,需要引入一种更适应小样本场景的训练范式。传统的分类学习模式将被摒弃,取而代之的是度量学习(Metric Learning)。其核心思想并非让模型学会“这是哪一类证件”,而是让模型学会如何“衡量两张证件的相似度”。具体实现将采用孪生网络(Siamese Network) 架构,并以 三元组损失(Triplet Loss) 作为优化目标。该方法通过向网络同时输入一个基准样本(Anchor)、一个同类样本(Positive)和一个异类样本(Negative),驱动模型学习一个更优的特征空间。在这个空间里,同类证件的特征向量在欧氏距离上被“拉近”,而异类证件的特征向量则被“推远”。这种“学习相似性”的策略,极大地降低了模型对每个类别样本数量的依赖,是解决小样本识别问题的理想方案。

因此,本篇博客将详细阐述如何应用度量学习,对现有的MobileNetV3特征提取器进行微调,以显著提升其在国外证件识别任务上的精度。后续章节将首先介绍为适应三元组训练而进行的样证库命名与结构优化,随后将深入探讨孪生网络的实现细节、三元组损失函数的设计、数据增强策略以及完整的模型微调流程。

二、样证库命名优化

2.1 命名现状与局限性

在前期的数据准备阶段,为快速构建样证库,采用了基于数字序号的极简命名约定:在每个具体的证件类别目录下,1.jpg2.jpg3.jpg4.jpg分别固定代表正面白光、反面白光、正面紫外和反面紫外四种图像。这种方式虽然在初始化时简单直观,但其内在的刚性结构,为后续的数据集扩展与模型训练带来了两个核心障碍:

  1. 可扩展性缺失:该命名体系是封闭的。每个图像类型(如“正面白光”)都被唯一地映射到一个固定的数字文件名。这导致无法为同一个证件类别添加更多的同类型样本。例如,若获取了同一版式证件的第二份“正面白光”图像,将无法以1.jpg为名存入,因为该名称已被占用。这一限制直接阻碍了样证库的扩充,使得数据集规模无法增长,这对于任何基于深度学习的微调任务而言都是一个根本性的瓶颈。

  2. 语义不明确:纯数字的命名方式缺乏自解释性。开发人员或数据管理员需要额外记忆或查阅文档才能理解每个数字与图像类型的对应关系,增加了维护成本和出错的可能性。

为解决这些问题,特别是在为后续的度量学习准备数据集的背景下,必须对现有的样证库进行一次彻底的、自动化的命名重构。

2.2 新的命名规范

新的命名规范旨在实现语义化可扩展性的统一。其核心思想是采用序号_类型.扩展名的格式。具体规则如下:

  • 文件名结构[index]_[image_type].jpg
  • 序号 ([index]):一个从1开始的递增序号。在本次迁移中,由于每个类别下同类型图像只有一个,所有现有图像的序号都将被统一设置为1
  • 类型映射 ([image_type])
    • 1.jpg -> front_white
    • 2.jpg -> back_white
    • 3.jpg -> front_uv
    • 4.jpg -> back_uv

例如,一个旧的1.jpg文件将被重命名为1_front_white.jpg。这种新格式的优势显而易见:未来若新增一份正面白光样本,可直接命名为2_front_white.jpg,以此类推,实现了数据集的无缝扩展。

2.3 关联数据文件的同步更新

在样证采集中,部分图像可能附带有同名的JSON标注文件(如1.json对应1.jpg),用于记录版面中特定区域的坐标信息。这些JSON文件内部包含一个imagePath字段,其值指向关联的图像文件名。

因此,在重构文件名的同时,必须执行同步更新:

  1. JSON文件名同步:将1.json重命名为1_front_white.json
  2. 内部路径同步:打开1_front_white.json文件,将其中的"imagePath": "1.jpg"字段修改为"imagePath": "1_front_white.jpg"

2.4 自动化迁移脚本

为确保整个迁移过程高效、准确且无遗漏,下面提供一个专门的Python脚本rename_samples.py。该脚本将递归遍历整个samples目录,自动应用上述所有重命名与内容更新规则。

代码清单: rename_samples.py

import os
import json
from pathlib import Path

# 定义旧文件名到新类型名的映射关系
NAME_MAPPING = {
    "1": "front_white",
    "2": "back_white",
    "3": "front_uv",
    "4": "back_uv"
}

def update_json_image_path(json_path: Path, new_image_name: str):
    """
    读取JSON文件,更新其imagePath字段,并写回。
    增加了对多种编码格式(UTF-8, GBK)的兼容处理。
    
    Args:
        json_path (Path): JSON文件的路径。
        new_image_name (str): 新的图像文件名。
    """
    data = None
    # 尝试使用UTF-8编码打开,这是标准做法
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except UnicodeDecodeError:
        print(f"    - '{json_path.name}' is not UTF-8 encoded. Trying GBK...")
        # 如果UTF-8失败,尝试使用GBK编码,这在处理中文Windows环境下生成的文件时很常见
        try:
            with open(json_path, 'r', encoding='gbk') as f:
                data = json.load(f)
        except Exception as e:
            print(f"    - [ERROR] Failed to read '{json_path.name}' with both UTF-8 and GBK: {e}")
            return
    except Exception as e:
        print(f"    - [ERROR] An unexpected error occurred while reading '{json_path.name}': {e}")
        return

    if data is None:
        return

    # 更新imagePath字段
    data['imagePath'] = new_image_name
    
    # 以UTF-8编码写回文件,实现编码统一
    try:
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"    - Updated imagePath in '{json_path.name}' and saved as UTF-8.")
    except Exception as e:
        print(f"    - [ERROR] Failed to write updated JSON to '{json_path.name}': {e}")


def process_directory(current_dir: Path):
    """
    处理单个目录,对其下的1/2/3/4.jpg/.json文件进行重命名。
    """
    print(f"\nProcessing directory: '{current_dir}'")
    
    for old_base_name, new_type_name in NAME_MAPPING.items():
        old_image_path = current_dir / f"{old_base_name}.jpg"
        old_json_path = current_dir / f"{old_base_name}.json"

        if old_image_path.exists():
            # 修正并确认新的文件名格式
            new_image_name = f"1_{new_type_name}.jpg"
            new_image_path = current_dir / new_image_name
            
            os.rename(old_image_path, new_image_path)
            print(f"  - Renamed '{old_image_path.name}' -> '{new_image_path.name}'")

            if old_json_path.exists():
                new_json_name = f"1_{new_type_name}.json"
                new_json_path = current_dir / new_json_name
                
                os.rename(old_json_path, new_json_path)
                print(f"    - Renamed '{old_json_path.name}' -> '{new_json_path.name}'")
                
                update_json_image_path(new_json_path, new_image_name)

def main():
    """
    主函数,遍历samples目录并启动处理流程。
    """
    samples_root = Path("samples")
    if not samples_root.is_dir():
        print(f"Error: Samples directory '{samples_root}' not found.")
        return

    for country_dir in samples_root.iterdir():
        if country_dir.is_dir():
            for state_dir in country_dir.iterdir():
                if state_dir.is_dir():
                    for template_dir in state_dir.iterdir():
                        if template_dir.is_dir():
                            process_directory(template_dir)
                            
    print("\n--- Renaming process completed successfully. ---")


if __name__ == "__main__":
    main()

将此脚本放置于项目根目录(与samples文件夹同级)并执行,它将自动完成对整个样证库的结构化重命名。迁移完成后,数据集的结构将变得清晰、语义化,并为后续添加更多样本、开展模型微调训练奠定了坚实的数据基础。

接下来,手工添加新样本,并按照上述命名规范进行命名。

2.5 样证统计

在完成了样证库的命名规范化并手动扩充了部分新样本之后,对现有数据集的规模和多样性进行一次全面的盘点,是进入模型训练阶段前的一项关键准备工作。清晰地了解数据全貌,有助于评估后续训练任务的可行性,并为制定合理的数据划分与增强策略提供依据。

为实现这一目标,下面提供一个简洁的Python脚本statistics_samples.py。该脚本将自动遍历samples目录,依据其三层结构(国家 -> 省/州/地区 -> 模板ID),精确统计出当前样证库覆盖的国家总数,以及收集到的独立证件模板总数。

代码清单: statistics_samples.py

from pathlib import Path

def main():
    """
    主函数,遍历samples目录,统计国家数量和样证模板总数。
    """
    samples_root = Path("samples")
    if not samples_root.is_dir():
        print(f"错误: 未找到样证库目录 '{samples_root}'。")
        return

    country_count = 0
    template_count = 0

    print("开始统计样证库数据...")

    # 遍历第一层:国家目录
    country_dirs = [d for d in samples_root.iterdir() if d.is_dir()]
    country_count = len(country_dirs)

    # 遍历第二层和第三层以统计模板总数
    for country_dir in country_dirs:
        for state_dir in country_dir.iterdir():
            if state_dir.is_dir():
                # 计算当前省/州/地区目录下的模板数量
                num_templates_in_state = len([d for d in state_dir.iterdir() if d.is_dir()])
                template_count += num_templates_in_state
    
    print("\n--- 样证库统计结果 ---")
    print(f"覆盖国家/地区总数: {country_count} 个")
    print(f"独立样证模板总数: {template_count} 本")
    print("--------------------------")

if __name__ == "__main__":
    main()

将此脚本放置于项目根目录并执行,即可快速获得样证库的宏观统计数据。

python statistics_samples.py

执行后,脚本将输出如下格式的报告:

开始统计样证库数据...

--- 样证库统计结果 ---
覆盖国家/地区总数: 132 个
独立样证模板总数: 330 本
--------------------------

这份统计结果直观地反映了数据集的现状,为后续的度量学习微调提供了关键的基线信息。

三、算法微调和优化

在样证库结构优化与数据扩充完成后,便进入了提升系统识别精度的核心环节——对特征提取模型进行微调。此阶段的目标是,利用已有的、经过规范化整理的样证数据集,通过度量学习(Metric Learning)范式,对预训练的MobileNetV3模型进行优化,使其生成的特征向量能够更好地服务于国外证件的精细化区分任务。

3.1 理论基础:度量学习与孪生网络

传统的监督学习(如分类任务)在小样本场景下表现不佳,模型容易过拟合。度量学习则提供了一种截然不同的思路,其目标不是学习将样本映射到某个固定类别,而是学习一个度量函数或一个特征空间,在这个空间中,相似样本的距离近,不相似样本的距离远。

  • 孪生网络 (Siamese Network): 这是实现度量学习的经典架构。它由两个或多个共享相同权重和架构的子网络组成。在训练过程中,不同的输入样本(例如,两张不同的证件图像)分别通过这些相同的子网络,被映射到同一个特征空间,生成各自的特征向量。由于网络权重共享,这种映射方式是一致且可比的。

  • 三元组损失 (Triplet Loss): 这是驱动孪生网络学习的关键。训练时,不再向网络输入单个样本,而是输入一个“三元组”:

    • 基准样本 (Anchor): 一个随机选取的基准证件图像。
    • 正样本 (Positive): 与基准样本属于同一类别的另一张证件图像。
    • 负样本 (Negative): 与基准样本属于不同类别的证件图像。

    三元组损失函数的目标是,在特征空间中,最小化基准样本与正样本之间的距离(将同类“拉近”),同时最大化基准样本与负样本之间的距离(将异类“推远”),并确保两者之间存在一个预设的边界(margin)。

3.2 数据集准备与三元组生成

为了实施三元组训练,需要创建一个能够动态生成有效三元组的数据加载器。这个加载器不仅要处理图像的读取和预处理,还必须实现一套智能化的负样本采样策略,以应对证件识别的特殊挑战。

  • 正反面数据分离:证件的正面和反面版式迥异,不应混合比较。因此,在构建数据集时,所有以 _front_white.jpg 结尾的图像将被归入“正面数据集”,所有以 _back_white.jpg 结尾的图像则归入“反面数据集”。后续的三元组生成将在这两个数据集中独立进行。

  • 图像预处理:为消除不同采集设备或环境光照带来的色彩差异,所有图像在送入模型前都将经过一个“去色”处理。具体操作为:首先将彩色图像转换为灰度图,然后再将该灰度图转换回三通道的RGB图像。这一步骤保留了图像的结构和纹理信息,同时标准化了色彩空间,增强了模型的鲁棒性。

  • 智能负样本采样策略:一个高质量的负样本,应当是与基准样本相似但又分属不同类别的“困难样本”,这样能迫使模型学习更具辨识度的细微特征。基于此,三元组的生成逻辑如下:

    1. 随机选取一个基准样本 (Anchor)
    2. 从与基准样本同类的样本中,随机选取一个正样本 (Positive)
    3. 优先在与基准样本相同国家、但不同类别的样本中,随机选取一个负样本 (Negative)。例如,若基准是“美国加州驾照”,则优先选择“美国纽约州驾照”作为负样本。
    4. 若在同国家内找不到合适的负样本(例如,该国家在库中只有一个证件类别),则再从所有其他国家的样本中随机选取一个负样本。

3.3 模型微调流程

模型微调的完整流程将被封装在一个独立的训练脚本train_metric_learning.py中。

代码清单: train_metric_learning.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image, ImageDraw # 导入ImageDraw用于绘制
import os
from pathlib import Path
import random
import numpy as np

# --- 自定义数据增强 - 随机划痕 ---
class RandomScratches:
    """
    一个在PIL图像上绘制随机划痕的数据增强变换。
    """
    def __init__(self, num_scratches_range=(1, 5), p=0.8):
        self.num_scratches_range = num_scratches_range
        self.p = p

    def __call__(self, img):
        if random.random() > self.p:
            return img
        img_draw = img.copy()
        draw = ImageDraw.Draw(img_draw)
        width, height = img.size
        num_scratches = random.randint(*self.num_scratches_range)
        for _ in range(num_scratches):
            x1, y1 = random.randint(0, width), random.randint(0, height)
            x2, y2 = random.randint(0, width), random.randint(0, height)
            line_width = random.randint(1, 2)
            line_color = random.randint(50, 200)
            draw.line([(x1, y1), (x2, y2)], fill=(line_color, line_color, line_color), width=line_width)
        return img_draw


# --- 1. 数据集与三元组生成 ---

class TripletDataset(Dataset):
    """
    一个为度量学习生成三元组的数据集。
    
    - 实现了智能负样本采样策略。
    - 为孤立样本实现了包含随机划痕在内的完整数据增强流程。
    """
    def __init__(self, image_dir, image_type_suffix, main_transform=None):
        self.image_dir = Path(image_dir)
        self.transform_main = main_transform
        self.image_type_suffix = image_type_suffix
        
        self.positive_generator_transform = transforms.Compose([
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), fill=0),
            transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0)),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            RandomScratches(p=0.8),
            transforms.ToTensor(),
            transforms.RandomErasing(p=0.7, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0, inplace=False),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.samples = []
        self.class_to_images = {}
        self.country_to_classes = {}

        print(f"正在为 '{image_type_suffix}' 类型建立索引...")
        for country_dir in self.image_dir.iterdir():
            if country_dir.is_dir():
                country_name = country_dir.name
                country_has_valid_classes = False
                for state_dir in country_dir.iterdir():
                    if state_dir.is_dir():
                        for template_dir in state_dir.iterdir():
                            if template_dir.is_dir():
                                # --- 核心修复逻辑 ---
                                # 1. 首先查找是否存在所需类型的图像
                                image_files = list(template_dir.glob(f"*{self.image_type_suffix}"))
                                
                                # 2. 只有当找到图像时,才创建类别条目
                                if image_files:
                                    if not country_has_valid_classes:
                                        # 延迟创建国家条目,确保其不为空
                                        self.country_to_classes[country_name] = []
                                        country_has_valid_classes = True

                                    class_id = f"{country_name}_{state_dir.name}_{template_dir.name}"
                                    self.country_to_classes[country_name].append(class_id)
                                    self.class_to_images[class_id] = image_files
                                    
                                    for img_file in image_files:
                                        self.samples.append((class_id, img_file, country_name))

        print(f"索引建立完成,共找到 {len(self.samples)} 个有效样本。")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        anchor_class, anchor_path, anchor_country = self.samples[index]

        possible_positives = [p for p in self.class_to_images[anchor_class] if p != anchor_path]
        anchor_pil = Image.open(anchor_path).convert('L').convert('RGB')
        
        if possible_positives:
            positive_path = random.choice(possible_positives)
            positive_pil = Image.open(positive_path).convert('L').convert('RGB')
            positive_img = self.transform_main(positive_pil)
        else:
            anchor_pil_resized = transforms.Resize((224, 224))(anchor_pil)
            positive_img = self.positive_generator_transform(anchor_pil_resized)

        negative_class = None
        possible_neg_classes_in_country = [c for c in self.country_to_classes.get(anchor_country, []) if c != anchor_class]
        if possible_neg_classes_in_country:
            negative_class = random.choice(possible_neg_classes_in_country)
        else:
            all_other_classes = [c for c in self.class_to_images.keys() if c != anchor_class]
            negative_class = random.choice(all_other_classes)
            
        negative_path = random.choice(self.class_to_images[negative_class])
        negative_pil = Image.open(negative_path).convert('L').convert('RGB')

        anchor_img = self.transform_main(anchor_pil)
        negative_img = self.transform_main(negative_pil)

        return anchor_img, positive_img, negative_img

# --- 2. 模型定义 ---
class EmbeddingNet(nn.Module):
    def __init__(self, base_model):
        super(EmbeddingNet, self).__init__()
        self.features = nn.Sequential(*list(base_model.children())[:-1])
        self.flatten = nn.Flatten()
    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        return x

# --- 3. 训练脚本 ---
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    main_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    front_dataset = TripletDataset("samples", "_front_white.jpg", main_transform)
    back_dataset = TripletDataset("samples", "_back_white.jpg", main_transform)
    
    if not front_dataset.samples or not back_dataset.samples:
        print("错误:正面或反面数据集为空,无法开始训练。请检查 'samples' 目录。")
        return
        
    combined_dataset = torch.utils.data.ConcatDataset([front_dataset, back_dataset])
    dataloader = DataLoader(combined_dataset, batch_size=16, shuffle=True, num_workers=4)

    base_model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)
    model = EmbeddingNet(base_model).to(device)

    criterion = nn.TripletMarginLoss(margin=1.0)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    num_epochs = 100
    print("开始模型微调...")
    preloss = 100000000
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, (anchor, positive, negative) in enumerate(dataloader):
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
            
            optimizer.zero_grad()
            
            anchor_embedding = model(anchor)
            positive_embedding = model(positive)
            negative_embedding = model(negative)
            
            loss = criterion(anchor_embedding, positive_embedding, negative_embedding)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        epoch_loss = running_loss / len(dataloader)
        if epoch_loss < preloss:
            preloss = epoch_loss
            torch.save(model.state_dict(), "mobilenetv3_finetuned.pth")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    print("模型微调完成,权重已保存至 mobilenetv3_finetuned.pth")

if __name__ == "__main__":
    train()

该脚本完整实现了从数据加载、模型构建到训练循环的全过程,并引入了随机仿射变换作为数据增强手段,以进一步提升模型的泛化能力。

需要说明的是,本文使用的torch版本信息如下:

torch                     1.13.1+cu116
torchvision               0.14.1+cu116

3.4 集成微调后的模型

训练完成后,将生成一个mobilenetv3_finetuned.pth权重文件。最后一步是更新feature_extractor.py模块,加载并使用这个经过微调的、性能更强的模型,以替换原有的预训练模型。

代码清单: feature_extractor.py (更新后)

import io
import pickle

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms


class EmbeddingNet(nn.Module):
    """
    与训练时相同的模型结构定义,用于加载权重。
    """
    def __init__(self, base_model):
        super(EmbeddingNet, self).__init__()
        self.features = nn.Sequential(*list(base_model.children())[:-1])
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        return x

class ImageFeatureExtractor:
    """
    一个封装了微调后MobileNetV3模型的图像特征提取器。
    """
    def __init__(self, model_path="mobilenetv3_finetuned.pth"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # 1. 加载基础模型结构
        base_model = models.mobilenet_v3_large()
        self.model = EmbeddingNet(base_model).to(self.device)
        
        # 2. 加载微调后的权重
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

        # 3. 定义与训练时一致的图像预处理流程
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

    def extract_features(self, image_bytes: bytes) -> bytes:
        """
        接收图像的二进制数据,返回其特征向量的二进制序列化结果。
        """
        # "去色"处理
        image = Image.open(io.BytesIO(image_bytes)).convert("L").convert("RGB")
        
        input_tensor = self.preprocess(image)
        input_batch = input_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            output_features = self.model(input_batch)

        feature_np = output_features.cpu().numpy().flatten()
        return pickle.dumps(feature_np)

通过将feature_extractor.py更新为上述版本,整个后端服务在处理国外证件时,将自动调用经过度量学习微调的、针对性更强的特征提取模型。这标志着系统在国外证件识别的准确性和鲁棒性上完成了一次关键的迭代升级,使其能够更精准地区分版式相似的证件,从而为用户提供更可靠的识别结果。

3.5 重新计算数据库特征向量

在通过度量学习成功对特征提取模型进行微调后,系统获得了一个全新的、针对证件识别任务高度优化的特征空间。原有的database.db数据库中存储的特征向量,是由未经微调的通用MobileNetV3模型生成的,这些旧的特征向量与新模型所处的优化特征空间不兼容,无法用于进行有意义的相似度比对。

为了使整个识别系统能够完全利用微调带来的性能提升,必须确保数据库中存储的样证特征向量与运行时用于提取查询图像特征的模型是同源的。因此,需要对数据库中所有样证的特征向量进行一次彻底的更新。

最直接且最可靠的方案是完全重建数据库。该过程分为以下两个步骤:

  1. 删除旧数据库:在项目根目录下,手动删除database.db文件。这一步操作将彻底清除所有过时的数据,包括旧的表结构和由原始模型生成的特征向量,为新数据库的创建提供一个纯净的环境。

  2. 更新数据库初始化脚本 (init_db.py):原有的init_db.py脚本是基于旧的数字命名约定(1.jpg, 2.jpg等)设计的。现在必须对其进行升级,使其能够识别并处理本篇博客中引入的、更具语义化的新命名规范(1_front_white.jpg, 1_back_uv.jpg等)。升级后的脚本将使用经过微调的ImageFeatureExtractor,遍历samples目录,为每一个样证模板重新提取并存储其所有图像的特征向量。

代码清单: init_db.py (更新后)

import re
from pathlib import Path

from sqlmodel import Session, select, SQLModel

from database import engine
from feature_extractor import ImageFeatureExtractor # 确保导入更新后的特征提取器
from models import CertificateTemplate, Country


def read_image_bytes(image_path: Path) -> bytes:
    """
    安全地读取图像文件,返回其二进制内容。
    如果文件不存在或无法读取,则返回一个空的字节对象。
    """
    if image_path and image_path.exists() and image_path.is_file():
        return image_path.read_bytes()
    return b''


def main():
    """
    主函数,执行数据库初始化和标准样证数据的填充。
    - 使用微调后的模型重新计算所有特征向量。
    - 正确处理同一类别下的多个独立样本。
    """
    print("开始初始化数据库...")
    # 确保数据库表已根据模型创建
    SQLModel.metadata.create_all(engine)
    print("数据库表结构已确认/创建。")

    with Session(engine) as session:
        # 实例化使用微调后权重的特征提取器
        extractor = ImageFeatureExtractor(model_path="mobilenetv3_finetuned.pth")
        print("已加载微调后的图像特征提取器。")

        samples_dir = Path("samples")
        if not samples_dir.exists():
            print(f"错误: 未找到样本数据目录 '{samples_dir}'。初始化终止。")
            return

        for country_dir in samples_dir.iterdir():
            if not country_dir.is_dir():
                continue

            dir_name = country_dir.name
            match = re.match(r'(.+)_(\d{3})$', dir_name)
            if not match:
                print(f"警告: 跳过格式不正确的国家目录 '{dir_name}'。")
                continue
            
            country_name, country_code = match.groups()

            # 检查或创建国家记录
            statement = select(Country).where(Country.code == country_code)
            db_country = session.exec(statement).first()
            if not db_country:
                print(f"数据库中未找到国家 '{country_name}',正在创建...")
                db_country = Country(name=country_name, code=country_code)
                session.add(db_country)
                session.commit()
                session.refresh(db_country)
                print(f"国家 '{country_name}' 创建成功。")

            for state_dir in country_dir.iterdir():
                if not state_dir.is_dir():
                    continue
                state_name = state_dir.name

                for template_dir in state_dir.iterdir():
                    if not template_dir.is_dir():
                        continue
                    
                    template_id = template_dir.name
                    print(f"\n正在处理类别目录: {country_name} - {state_name} - {template_id}")

                    # --- 核心修改:识别并遍历目录下的所有独立样本 ---
                    
                    # 1. 扫描目录,根据文件名前缀的数字识别出所有样本索引
                    all_files = list(template_dir.glob("*.jpg"))
                    sample_indices = sorted(list(set([f.name.split('_')[0] for f in all_files])))
                    
                    if not sample_indices:
                        print("  - 目录下无有效样本文件,跳过。")
                        continue
                        
                    print(f"  - 在该类别下发现 {len(sample_indices)} 个独立样本,索引为: {sample_indices}")

                    # 2. 遍历每一个识别出的样本索引
                    for index in sample_indices:
                        print(f"    - 正在处理样本索引: {index}")
                        
                        # 3. 为当前样本索引查找对应的四种图像
                        front_white_path = template_dir / f"{index}_front_white.jpg"
                        back_white_path = template_dir / f"{index}_back_white.jpg"
                        front_uv_path = template_dir / f"{index}_front_uv.jpg"
                        back_uv_path = template_dir / f"{index}_back_uv.jpg"

                        # 4. 读取图像二进制数据
                        front_white_bytes = read_image_bytes(front_white_path)
                        back_white_bytes = read_image_bytes(back_white_path)
                        front_uv_bytes = read_image_bytes(front_uv_path)
                        back_uv_bytes = read_image_bytes(back_uv_path)

                        # 5. 使用新模型为当前样本提取特征向量
                        print("      - 正在提取特征向量...")
                        feature_front_white = extractor.extract_features(front_white_bytes) if front_white_bytes else b''
                        feature_back_white = extractor.extract_features(back_white_bytes) if back_white_bytes else b''
                        feature_front_uv = extractor.extract_features(front_uv_bytes) if front_uv_bytes else b''
                        feature_back_uv = extractor.extract_features(back_uv_bytes) if back_uv_bytes else b''
                        print("      - 特征向量提取完成。")

                        # 6. 构建描述信息
                        template_name_str = f"{country_name} {state_name} 样证模板 {template_id}"
                        if state_name.lower() == 'other':
                            template_name_str = f"{country_name} 样证模板 {template_id}"
                        
                        # 在描述中加入样本索引以作区分
                        template_desc_str = f"标准样证 - {country_name} - {state_name} - 模板编号 {template_id} (样本 {index})"

                        # 7. 创建数据库记录
                        new_template_instance = CertificateTemplate(
                            name=template_name_str,
                            description=template_desc_str,
                            image_front_white=front_white_bytes,
                            image_front_uv=front_uv_bytes,
                            image_back_white=back_white_bytes,
                            image_back_uv=back_uv_bytes,
                            feature_front_white=feature_front_white,
                            feature_front_uv=feature_front_uv,
                            feature_back_white=feature_back_white,
                            feature_back_uv=feature_back_uv,
                            country_id=db_country.id
                        )

                        session.add(new_template_instance)
                        print(f"    - 样本 {index} 已准备好写入数据库。")

        # 提交所有会话中的更改
        session.commit()
        print("\n所有数据已成功写入数据库。")
    print("数据库初始化完成。")


if __name__ == "__main__":
    main()
  1. 执行脚本与验证:在执行数据库重建之前,首先需要利用Alembic工具根据models.py中的定义,创建出空的数据库表结构。在项目根目录的终端中,执行以下命令:

    alembic upgrade head
    

    该命令会生成一个包含正确表结构、但没有任何数据的database.db文件。

    随后,运行更新后的数据库初始化脚本:

    python init_db.py
    

    脚本将遍历samples目录,使用微调后的新模型为所有样证图像计算特征向量,并将这些高质量的特征连同图像数据一并填充到新建的数据库中。

    执行完毕后,整个后端系统的数据基础便与经过优化的特征提取模型完全对齐。此时,当客户端发起国外证件的识别请求时,后端服务将能够在全新的、更具区分度的特征空间中进行相似度比对,从而显著提升识别的准确率和可靠性。

Logo

更多推荐