转自AI Studio,原文链接:一文看懂基于PaddleOCR的表格结构识别算法 - 飞桨AI Studio

📖 0 项目背景

PaddleOCR是百度开源的超轻量级OCR模型库,提供了数十种文本检测、识别模型,旨在打造一套丰富、领先、实用的文字检测、识别模型/工具库,助力使用者训练出更好的模型,并应用落地。

关于如何使用PaddleOCR,平台已经有了非常丰富的教程,可以在十分钟内快速启动自定义图片的OCR任务[1],但该类项目注重于简单任务的快速实现。

或者基于该模型库完成自己的OCR任务,但这些项目更注重于项目的完成效果,而不是讲解如何对该模型库进行定制化以适应自己的任务。

因此,这些项目对于PaddleOCR的学习有一定门槛,不能满足初学者希望快速修改实现自己项目的需求。

本项目将针对该痛点,手把手教你针对同花顺算法竞赛数据,使用PaddleOCR完成表格结构识别。

在该过程中,会结合官网教程。详细讲解训练数据的制作,训练模型参数,推理结果的读取等。

并针对数据特点,提出对表格进行切割后复原的数据处理方法。

🥡 1 项目功能

项目介绍了PaddleOCR的关键参数和在同花顺算法竞赛中的使用。

如果有以下需求,该项目可能对您有用:

  • 希望快速上手PaddleOCR并对其进行修改

  • 希望使用PaddleOCR完成表格结构识别算法

  • 希望参与后续同花顺算法竞赛或使用比赛数据

💡 2 PaddleOCR及表格结构识别介绍

2.1 Paddle OCR特性:

  • 超轻量级中文OCR,总模型仅8.6M

    • 单模型支持中英文数字组合识别、竖排文本识别、长文本识别

    • 检测模型DB(4.1M)+识别模型CRNN(4.5M)

  • 多种文本检测训练算法,EAST、DB

  • 多种文本识别训练算法,Rosetta、CRNN、STAR-Net、RARE

2.2 表格结构识别任务:

表格作为一种高效的数据组织与展现方法被广泛应用,已成为各类文档中最常见的页面对象。

目前很大一部分文档以图片的形式存在,无法直接获取表格信息。

人工还原表格既费时又容易出错,因此如何自动并准确地从文档图片中识别出表格成为一个亟待解决的问题。

但由于表格大小、种类与样式的复杂多样(例如表格中存在不同的背景填充、不同的行列合并方法、不同的分割线类型等),导致表格识别一直是文档识别领域的研究难点。

同花顺算法竞赛专注于表格结构识别,为选手提供了已标注的表格图片数据,需要选手通过深度学习的方法,识别出表格结构并输出。

💡 3 数据集介绍

3.1 赛题任务

训练数据主要包括原始图片及对应的ground truth,ground truth内包含表格位置信息和单元格信息。

选手可以直接使用ground truth内的表格位置信息,也可以使用自己预测的表格位置信息。

在得到表格区域的基础上,选手需要将表格的结构识别出来,

输出单元格的行列结构信息及单元格内的文字位置信息。

3.2 数据说明

640张训练集、106张测试集A、108张测试集B及其对应的ground truth(xml文件)

选手可以直接使用ground truth内的表格位置信息,也可以使用自己预测的表格位置信息。

在得到表格区域的基础上,选手需要将表格的结构识别出来,

输出单元格的行列结构信息及单元格内的文字位置信息。


  • ground truth字段说明

    1. table:表格,包含表格位置信息及该表格内的单元格信息。points字段为“x0, y0 x1, y1 x2,y2 x3, y3”格式,表示表格区域的四个角点,角点顺序不固定

    2. cell:单元格,包含行列信息及位置信息

    3.start_col、end_col、start_row、end_row:单元格所处的行列信息

    4.points:单元格内文本的位置信息,格式为“x0, y0 x1, y1 x2,y2 x3, y3”,表示文本区域的四个角点,角点顺序不固定;当单元格内存在多行文本时,取所有文本的最小外包矩形作为文本区域

  • ground truth示例如下

参考资料

  • [1]. https://aistudio.baidu.com/aistudio/projectdetail/1798439
  • [2]. http://contest.aicubes.cn/#/detail?topicId=51
  • [3]. https://aistudio.baidu.com/aistudio/projectdetail/3639862

🍰 4 项目整体思路

  • 单张表格的标注信息如下:

  • 比赛提供的数据集图片为:

4.1 官方Baseline思路

官方Baseline主要使用分割模型来对表格进行结构分析。

1. 表格可以通过xml文件获取边界框,不需要检测

2. 表格结构分析基于unet语义分割模型来做

生成两个图层,分别是表格的横向线纵向线(有线表格和无线表格都按照有线处理)。

 

3. 分割完成后,表格被横纵线条阶段,用opencv找矩形,还原行列结构,形成cell列表,示意图如下:

4. 以cell为单位,遍历每个文字对象的中心点是否落在本cell中,若中心点在cell中,将四个角点坐标都加入列表L,遍历完成后,取L中所有点的最小外接矩形作为文本框的坐标框

4.2 该项目实现思路

本项目主要使用目标检测对表格进行结构分析。

为避免分割中会存在线段多分割或少分割的情况,本项目直接对每个cell进行使用PaddleOCR进行定位。

此外,为了排除非表格信息的影响,我们先将表格位置图片裁剪下来训练模型。

随后,根据原表格在原图中的位置对表格进行复原,既保证输入模型的数据量少,又能提高模型的表现。

以测试图片 0-10.jpg 为例


1. 根据xml切分图像表格

 

2. 使用PaddleOCR根据切分表格训练每个cell的定位模型

3. 根据每个cell的定位结果进行还原

📲 5 代码实现

In [ ]
# 解压数据集
!unzip -oq /home/aistudio/data/data137832/文档图片表格结构识别算法.zip -d /home/aistudio/work/
In [15]
# 读取数据集,并使用point2loc对无序坐标点进行顺序化
import glob
import os
import xml
from xml.dom import minidom

train_img_path = 'work/train_new/imgs/*.jpg'
train_gt_path = 'work/train_new/gt/*.xml'

train_img = glob.glob(train_img_path)
train_gt = glob.glob(train_gt_path)

def point2loc(rand_corners):
    # 输入顺序不固定的四个角点,按照顺时针输出顺序固定的四个角点
    corners = [[int(item.split(',')[0]),int(item.split(',')[1])] for item in rand_corners.split(' ')] 

    # 找到中心点
    center_point = [0,0]
    for item in corners:
        center_point[0] += item[0]
        center_point[1] += item[1]
    center_point = [item/len(corners) for item in center_point]
    # 判断
    p1,p2,p3,p4 = [0,0],[0,0],[0,0],[0,0]
    for point in corners:
        if point[0]<center_point[0] and  point[1]<center_point[1]:
            p1 = point
        elif point[0]<center_point[0] and  point[1]>center_point[1]:
            p4 = point
        elif point[0]>center_point[0] and  point[1]>center_point[1]:
            p3 = point
        elif point[0]>center_point[0] and  point[1]<center_point[1]:
            p2 = point
        else:
            print('数据发生异常,请排查数据')
            
    return [p1,p2,p3,p4]

# 读取并处理xml文件
table_all = {}
for xml_file in train_gt:
    #打开xml文档
    dom = minidom.parse(xml_file)
    #得到文档元素对象
    root = dom.documentElement
    #table
    tables_lists=root.getElementsByTagName('table')

    #每一个table,记录返回的cell信息
    cell_all = {}
    table_loc = {}
    for idx,table_ in enumerate(tables_lists):
        nodes=table_.childNodes
        # 字段为“x0, y0 x1, y1 x2,y2 x3, y3”格式
        # 表示表格区域的四个角点,角点顺序不固定
        table_pts =point2loc(nodes[0].getAttribute('points'))
        # 记录每个表格的角点
        table_loc[idx] = table_pts
        # 记录每个表格中的cell信息
        # (当单元格内存在多行文本时,取所有文本的最小外包矩形作为文本区域)
        cell_infos = []
        for j in range(1, len(nodes)):
            info_temp = {
                        # start_col、end_col、start_row、end_row:单元格所处的行列信息
                        'row_col': [nodes[j].getAttribute('start-row'), nodes[j].getAttribute('end-row'),
                                    nodes[j].getAttribute('start-col'), nodes[j].getAttribute('end-col')],
                        # 格式为“x0, y0 x1, y1 x2,y2 x3, y3”,
                        # 表示文本区域的四个角点,角点顺序不固定
                        'points': point2loc(nodes[j].firstChild.getAttribute('points'))}
            cell_infos.append(info_temp)
        cell_all[idx] = cell_infos
        cell_all['table_loc'] = table_loc

    table_all[xml_file.replace('gt','imgs').replace('.xml','.jpg')] = cell_all

    
In [16]
# 将图像数据重新切分,使其只包含表格,同时修改坐标。
import os
import cv2
import numpy as np
import json


def dict_slice(adict, start, end):
    # 字典切片
    keys = adict.keys()
    dict_slice = {}
    for k in list(keys)[start:end]:
        dict_slice[k] = adict[k]
    return dict_slice


train_folder = 'work/train_data'
train_path = 'work/train_data.txt'

test_folder = 'work/eval_data'
test_path = 'work/eval_data.txt'

# 切分一部分数据用于模型验证

if  not os.path.exists(train_folder):
    os.mkdir(train_folder)

if  not os.path.exists(test_folder):
    os.mkdir(test_folder)

train_file=open(train_path,mode='w')
test_file=open(test_path,mode='w')

scale  = 0.8 # 切分比例
train_file_img = dict_slice(table_all,0,int(scale*len(table_all)))
test_file_img =  dict_slice(table_all,int(scale*len(table_all)),int(len(table_all)))

# 训练集
for idx,(key,values) in enumerate(train_file_img.items()):
    img_path = key
    img = cv2.imread(img_path)
    if img is None:
        # 存在png数据集
        img = cv2.imread(img_path[:-3] +  'png')

    for table_id in range(len(values)-1):
        Img_name = os.path.join(train_folder,img_path.split('/')[-1].replace('.jpg','')+"_%d" % (table_id+1)+'.jpg')
        table_loc = values['table_loc'][table_id]
        table_pic = img[table_loc[0][1]:table_loc[2][1],table_loc[0][0]:table_loc[2][0],:]
        cv2.imwrite(Img_name, table_pic,[int(cv2.IMWRITE_JPEG_QUALITY), 100])
        # 可视化结果
        # img = cv2.rectangle(img,(table_loc[0][0],table_loc[0][1]), (table_loc[2][0],table_loc[2][1]), (255, 0, 0), 2)
        # cv2.imwrite(os.path.join(train_folder,img_path.split('/')[-1]), img,[int( cv2.IMWRITE_JPEG_QUALITY), 100])
        cell_list = []
        for cell in values[table_id]:
            cell_point = cell['points']
            # 获得该图像表格左上角原始坐标
            left_top_point = table_loc[0]

            cell_point = [[item[0]-left_top_point[0],item[1]-left_top_point[1]] for item in cell_point]
            temp_cell = {'transcription':'#','points':cell_point}
            cell_list.append(json.dumps(temp_cell))
        cell_list_ = str(cell_list).replace('\'{','{').replace('}\'','}')
        train_file.write(Img_name.replace('work/','') +'\t' +str(cell_list_)+'\n')
train_file.close()

# 验证集
for idx,(key,values) in enumerate(test_file_img.items()):
    img_path = key
    img = cv2.imread(img_path)
    if img is None:
        # 存在png数据集
        img = cv2.imread(img_path[:-3] +  'png')

    for table_id in range(len(values)-1):
        Img_name = os.path.join(test_folder,img_path.split('/')[-1].replace('.jpg','')+"_%d" % (table_id+1)+'.jpg')
        table_loc = values['table_loc'][table_id]
        table_pic = img[table_loc[0][1]:table_loc[2][1],table_loc[0][0]:table_loc[2][0],:]
        cv2.imwrite(Img_name, table_pic,[int(cv2.IMWRITE_JPEG_QUALITY), 100])
        # 可视化结果
        # img = cv2.rectangle(img,(table_loc[0][0],table_loc[0][1]), (table_loc[2][0],table_loc[2][1]), (255, 0, 0), 2)
        # cv2.imwrite(os.path.join(test_folder,img_path.split('/')[-1]), img,[int( cv2.IMWRITE_JPEG_QUALITY), 100])
        cell_list = []
        for cell in values[table_id]:
            cell_point = cell['points']
            # 获得该图像表格左上角原始坐标
            left_top_point = table_loc[0]

            cell_point = [[item[0]-left_top_point[0],item[1]-left_top_point[1]] for item in cell_point]
            temp_cell = {'transcription':'#','points':cell_point}
            cell_list.append(json.dumps(temp_cell))
        cell_list_ = str(cell_list).replace('\'{','{').replace('}\'','}')
        test_file.write(Img_name.replace('work/','') +'\t' +str(cell_list_)+'\n')
test_file.close()
In [ ]
%cd ~ 
<span style="color:rgba(0, 0, 0, 0.85)"><span style="background-color:#ffffff">/home/aistudio
</span></span>
In [4]
%cd ~ 
## 源库函数,本项目使用代码已经过修改,直接解压上传的文件
# !git clone -b release/2.1 https://gitee.com/PaddlePaddle/PaddleOCR.git
!unzip -oq /home/aistudio/data/data137832/PaddleOCR.zip -d /home/aistudio/work/
!mv /home/aistudio/work/home/aistudio/PaddleOCR/ /home/aistudio
In [9]
# 安装依赖库
%cd ~/PaddleOCR
!pip install -r requirements.txt -i https://mirror.baidu.com/pypi/simple

5.1 PaddleOCR预训练模型使用演示

该部分详细内容建议参考项目十分钟掌握PaddleOCR使用

In [10]
! mkdir inference
# 下载超轻量级中文OCR模型的检测模型并解压
! cd inference && wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && rm ch_ppocr_mobile_v2.0_det_infer.tar
# 下载超轻量级中文OCR模型的识别模型并解压
! cd inference && wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && rm ch_ppocr_mobile_v2.0_rec_infer.tar
# 下载超轻量级中文OCR模型的文本方向分类器模型并解压
! cd inference && wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar && tar xf ch_ppocr_mobile_v2.0_cls_infer.tar && rm ch_ppocr_mobile_v2.0_cls_infer.tar
! cd ..
In [11]
import matplotlib.pyplot as plt
from PIL import Image
%pylab inline

def show_img(img_path,figsize=(10,10)):
    ## 显示原图,读取名称为11.jpg的测试图像
    img = Image.open(img_path)
    plt.figure("test_img", figsize=figsize)
    plt.imshow(img)
    plt.show()
show_img("./doc/imgs_en/11-1.jpg")
<span style="color:rgba(0, 0, 0, 0.85)"><span style="background-color:#ffffff">Populating the interactive namespace from numpy and matplotlib
</span></span>
<span style="color:rgba(0, 0, 0, 0.85)"><span style="background-color:#ffffff"><Figure size 720x720 with 1 Axes></span></span>
In [12]
# 快速运行
!python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/11-1.jpg" \
--det_model_dir="./inference/ch_ppocr_mobile_v2.0_det_infer"  \
--rec_model_dir="./inference/ch_ppocr_mobile_v2.0_rec_infer" \
--cls_model_dir="./inference/ch_ppocr_mobile_v2.0_cls_infer"
In [ ]
## 显示轻量级模型识别结果
show_img("./inference_results/11-20.jpg",figsize=(20,20))
In [14]
%cd ~/PaddleOCR/train_data/
<span style="color:rgba(0, 0, 0, 0.85)"><span style="background-color:#ffffff">[Errno 2] No such file or directory: '/home/aistudio/PaddleOCR/train_data/'
/home/aistudio/PaddleOCR
</span></span>

5.2 PaddleOCR表格结构定位预训练模型下载

PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3ResNet_vd系列。

您可以根据需求使用PaddleClas中的模型更换backbone, 对应的backbone预训练模型可以从PaddleClas repo主页中找到下载链接。教程

cd PaddleOCR/
# 根据backbone的不同选择下载对应的预训练模型
# 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
# 或,下载ResNet18_vd的预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
# 或,下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
In [ ]
# 下载MobileNetV3的预训练模型
!wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
! cd pretrain_models/ && tar xf MobileNetV3_large_x0_5_pretrained.tar
# 下载ResNet50的预训练模型
!wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
! cd pretrain_models/ && tar xf ResNet50_vd_ssld_pretrained.tar

5.3 训练

5.3.1 数据要求 在使用框架中,最重要的就是数据格式的对齐。在本任务中,要求的格式为icdar2015数据集格式,见格式

将训练图片放入同一个文件夹(train_images),并用一个txt文件(rec_gt_train.txt)记录图片路径和标签。

  • User的数据集应该 有两个文件夹和两个文件,按照如下方式组织训练数据集:

    任意目录/
      └─ 训练图像文件夹/           数据集的训练数据
      └─ 验证图像文件夹/           数据集的测试数据
      └─ train_label.txt           数据集的训练标注
      └─ teval_label.txt           数据集的测试标注
    
  • 提供的标注文件格式如下,中间用"\t"分隔:

     " 图像文件名                    json.dumps编码的图像标注信息"
    ch4_test_images/img_61.jpg    [{"transcription": "", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]
    

json.dumps编码前的图像标注信息是包含多个字典的list,

字典中的 points 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。

transcription 表示当前文本框的文字,当其内容为“###”时,表示该文本框无效,在训练时会跳过。

  • 本项目的数据构造过程可以见项目第五章的数据集生成部分。

5.3.2 参数设置

参数设置可以极大的影响模型的训练结果和训练时间。

直接使用默认参数时,本项目的precision精度仅约86%,通过参数调整,最终的precision约93%。

PaddleOCR将网络划分为四部分,分别在ppocr/modeling下。 进入网络的数据将按照顺序(transforms->backbones-> necks->heads)依次通过这四个部分。

参数的选择会直接影响模型的性能。

例如:模型主干网络选择对比

包括模型结构的自定义实现,[教程](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/doc/doc_ch/detection.md)

此外,根据如下代码,模型根据yml的参数进行训练。模型的参数含义解释见教程


训练时间14H

# 训练backbone为的db算法的检测模型
!python3 tools/train.py -c configs/det/det_r50_vd_db.yml -o \
Global.eval_batch_step="[0,1000]" \
Global.use_amp=True \
Global.scale_loss=1024.0  \
Global.use_dynamic_loss_scaling=True \
Global.load_static_weights=true \
Global.pretrained_model='./pretrain_models/ResNet50_vd_ssld_pretrained' \
Train.dataset.data_dir='../work/' \
Train.dataset.label_file_list=['../work/train_data.txt'] \
Eval.dataset.data_dir='../work/' \
Eval.dataset.label_file_list=['../work/eval_data.txt']
In [ ]
# 生成测试数据集
%cd ..
<span style="color:rgba(0, 0, 0, 0.85)"><span style="background-color:#ffffff">/home/aistudio
</span></span>
In [ ]
# 切分测试数据
import glob
import os
import xml
from xml.dom import minidom

train_img_path = 'work/test_a/imgs/*.jpg'
train_gt_path = 'work/test_a/gt/*.xml'

train_img = glob.glob(train_img_path)
train_gt = glob.glob(train_gt_path)

def point2loc(rand_corners):
    # 输入顺序不固定的四个角点,按照顺时针输出顺序固定的四个角点
    corners = [[int(item.split(',')[0]),int(item.split(',')[1])] for item in rand_corners.split(' ')] 

    # 找到中心点
    center_point = [0,0]
    for item in corners:
        center_point[0] += item[0]
        center_point[1] += item[1]
    center_point = [item/len(corners) for item in center_point]
    # 判断
    p1,p2,p3,p4 = [0,0],[0,0],[0,0],[0,0]
    for point in corners:
        if point[0]<center_point[0] and  point[1]<center_point[1]:
            p1 = point
        elif point[0]<center_point[0] and  point[1]>center_point[1]:
            p4 = point
        elif point[0]>center_point[0] and  point[1]>center_point[1]:
            p3 = point
        elif point[0]>center_point[0] and  point[1]<center_point[1]:
            p2 = point
        else:
            print('数据发生异常,请排查数据')
            
    return [p1,p2,p3,p4]

# 读取并处理xml文件
table_all = {}
for xml_file in train_gt:
    #打开xml文档
    dom = minidom.parse(xml_file)
    #得到文档元素对象
    root = dom.documentElement
    #table
    tables_lists=root.getElementsByTagName('table')

    #每一个table,记录返回的cell信息
    cell_all = {}
    table_loc = {}
    for idx,table_ in enumerate(tables_lists):
        nodes=table_.childNodes
        # 字段为“x0, y0 x1, y1 x2,y2 x3, y3”格式
        # 表示表格区域的四个角点,角点顺序不固定
        table_pts =point2loc(nodes[0].getAttribute('points'))
        # 记录每个表格的角点
        table_loc[idx] = table_pts
        # 记录每个表格中的cell信息
        # (当单元格内存在多行文本时,取所有文本的最小外包矩形作为文本区域)
        cell_infos = []
        for j in range(1, len(nodes)):
            info_temp = {
                        # start_col、end_col、start_row、end_row:单元格所处的行列信息
                        'row_col': [nodes[j].getAttribute('start-row'), nodes[j].getAttribute('end-row'),
                                    nodes[j].getAttribute('start-col'), nodes[j].getAttribute('end-col')],
                        # 格式为“x0, y0 x1, y1 x2,y2 x3, y3”,
                        # 表示文本区域的四个角点,角点顺序不固定
                        'points': point2loc(nodes[j].firstChild.getAttribute('points'))}
            cell_infos.append(info_temp)
        cell_all[idx] = cell_infos
        cell_all['table_loc'] = table_loc

    table_all[xml_file.replace('gt','imgs').replace('.xml','.jpg')] = cell_all
In [ ]
# 将图像数据重新切分,使其只包含表格,同时修改坐标。
import os
import cv2
import numpy as np
import json

def dict_slice(adict, start, end):
    # 字典切片
    keys = adict.keys()
    dict_slice = {}
    for k in list(keys)[start:end]:
        dict_slice[k] = adict[k]
    return dict_slice


label_folder = 'work/test_data'
label_path = 'work/test_data.txt'

if  not os.path.exists(label_folder):
    os.mkdir(label_folder)


label_file=open(label_path,mode='w')

# 训练集
for idx,(key,values) in enumerate(table_all.items()):
    img_path = key
    img = cv2.imread(img_path)
    if img is None:
        # 存在png数据集
        img = cv2.imread(img_path[:-3] +  'png')

    for table_id in range(len(values)-1):
        Img_name = os.path.join(label_folder,img_path.split('/')[-1].replace('.jpg','')+"_%d" % (table_id+1)+'.jpg')
        table_loc = values['table_loc'][table_id]
        table_pic = img[table_loc[0][1]:table_loc[2][1],table_loc[0][0]:table_loc[2][0],:]
        cv2.imwrite(Img_name, table_pic,[int(cv2.IMWRITE_JPEG_QUALITY), 100])
        # 可视化结果
        # img = cv2.rectangle(img,(table_loc[0][0],table_loc[0][1]), (table_loc[2][0],table_loc[2][1]), (255, 0, 0), 2)
        # cv2.imwrite(os.path.join(train_folder,img_path.split('/')[-1]), img,[int( cv2.IMWRITE_JPEG_QUALITY), 100])
        cell_list = []
        for cell in values[table_id]:
            cell_point = cell['points']
            # 获得该图像表格左上角原始坐标
            left_top_point = table_loc[0]

            cell_point = [[item[0]-left_top_point[0],item[1]-left_top_point[1]] for item in cell_point]
            temp_cell = {'transcription':'#','points':cell_point}
            cell_list.append(json.dumps(temp_cell))
        cell_list_ = str(cell_list).replace('\'{','{').replace('}\'','}')
        label_file.write(Img_name.replace('work/','') +'\t' +str(cell_list_)+'\n')
label_file.close()
In [ ]
%cd ~/PaddleOCR
<span style="color:rgba(0, 0, 0, 0.85)"><span style="background-color:#ffffff">/home/aistudio/PaddleOCR
</span></span>
#生成检测的mask
!python3 tools/infer_det.py -c configs/det/det_r50_vd_db.yml -o \
Global.infer_img="../work/test_data/" \
Global.pretrained_model="./output/det_r50_vd/best_accuracy"
In [ ]
%cd ..
<span style="color:rgba(0, 0, 0, 0.85)"><span style="background-color:#ffffff">/home/aistudio
</span></span>
In [ ]
import pandas as pd
import json
import os
import glob
import cv2
## 生成分块图到原图的mask图
# test
split_img = 'work/test_data'

split_cell = 'PaddleOCR/output/det_db/predicts_db.txt'
data = pd.read_table(split_cell,header=None, sep='\t')
data[0] = data[0].apply(lambda x: x.replace('../work/test_data/',''))

label_img = 'work/test_a/imgs'
label_table = 'work/test_a/gt'

save_mask_path = 'work/mask_test_a/'

if  not os.path.exists(save_mask_path):
    os.mkdir(save_mask_path)

train_img_path = 'work/test_a/imgs/*.jpg'
train_gt_path = 'work/test_a/gt/*.xml'

label_imgs = glob.glob(label_img+'/*.jpg')

label_img_flag,img_name = 0,0
for i in range(data.shape[0]):
    subimg_name = data.iloc[i,0]
    # 图片中的第几个表格
    sub_idx = int(subimg_name[:-4].split('_')[-1])
    # 该表格的原图坐标,取左上角
    img_name = subimg_name[:-5-len(list(str(sub_idx)))]  + subimg_name[-4:]
    img_path = 'work/test_a/imgs/'+img_name
    table_locs = table_all[img_path]['table_loc'][sub_idx-1][0]

    print(subimg_name,sub_idx,img_name)
    # 读取第一张图像
    if i ==0 and label_img_flag == 0:
        img_name_ = img_name
        label_img_flag = img_path
        # 更新
        img = cv2.imread(img_path)
        if img is None:
            # 存在png数据集
            img = cv2.imread(img_path[:-3] +  'png')

    # 当该图的子图处理完了,再处理下一张图像
    elif img_path!=label_img_flag:

        # 保存处理完的图像
        cv2.imwrite(save_mask_path+img_name_, img,[int(cv2.IMWRITE_JPEG_QUALITY), 100])
        img_name_ = img_name
        # 更新label_img_flag
        label_img_flag = img_path
        img = cv2.imread(img_path)
        if img is None:
            # 存在png数据集
            img = cv2.imread(img_path[:-3] +  'png')
        

    imgloc = data.iloc[i,1]
    locs = imgloc.replace('{"transcription": "", ','').replace(']}, ','').replace(']}]','').replace(']','').replace('[','').replace(' ','').split('"points":')
    locs = locs[1:]
    # 该表格的检测到的所有cell
    for idx,loc in enumerate(locs):

        location = loc.split(',')
        location = [int(item) for item in location]
        
        location = [item+table_locs[0] if i%2==0 else item+table_locs[1] for i,item in enumerate(location)]
        cv2.rectangle(img, (int(location[0]), int(location[1])), (int(location[4]), int(location[5])),(0,0,255),2)

🏛 6 效果对比

我们对同一张数据演示训练前和训练后的效果对比:
直接使用预训练模型使用训练完成的模型

训练完成的模型较默认预训练模型的表格识别精度大大提高。

🐱 7 项目总结

  • 项目主要讲解基于PPOCR完成同花顺-文档图片表格结构识别竞赛

  • 更改源为GITEE,增加拉取成功率,实现了坐标顺序处理。

  • 相比Baseline展示可视化结果,精度大大提高

  • 模型验证集达到precision: 0.933, recall: 0.864。该模型精度在榜单上可以排到TOP2。


特别注意:本项目还有Cell的定位部分未完成,作为后续的改进方向。


有任何问题,欢迎评论区留言交流
Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐