MMdetection批量推理图片


前言

mmdetection作为一个优秀的开源目标检测算法库,在训练模型方面是相当的方便,但是某些时候使用它进行推理时就有点难受,本文就演示如何批量推理图片(多张图片存放在文件夹中),mmdetection的版本是2.27.0


批量推理图片

mmdetection中想测试图片那必须得有对应的标注信息文件,要是没有的话调用官方api只能一张一张推理,慢的要死,还是自己弄一个文件靠谱。可以在根目录下创建一个batch_infer.py的文件,这里需要调用推理的api,我直接贴上代码:

import argparse
import os
from mmdet.apis import inference_detector, init_detector  #, show_result_pyplot
import cv2
from pathlib import Path
 
    
def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('--config', type=str, help='配置文件路径')
    parser.add_argument('--checkpoint-file', type=str, help='权重文件路径')
    parser.add_argument(
        '--img-dir', type=str,
        help='待检测图片路径')
    parser.add_argument('--out-dir', type=str, help='保存检测图片路径')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--score-thr',
        type=float,
        default=0.50,
        help='score threshold (default: 0.50)')
    args = parser.parse_args()
    
    return args
    
    
    
def show_result_pyplot(model, img, result, score_thr=0.3, fig_size=(15, 10)):
    """Visualize the detection results on the image.
    Args:
        model (nn.Module): The loaded detector.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        fig_size (tuple): Figure size of the pyplot figure.
    """
    if hasattr(model, 'module'):
        model = model.module
    img = model.show_result(img, result, score_thr=score_thr, show=False)
    return img
 
def main():  
    args = parse_args()
    
    # config文件
    config_file = args.config
    # 训练好的模型
    checkpoint_file = args.checkpoint_file
    # checkpoint_file = 'work_dirs/faster_rcnn_r50_fpn_1x_coco/epoch_300.pth'
    model = init_detector(config_file, checkpoint_file, device='cuda:0')
 
    # 图片路径
    img_dir = args.img_dir
    # 检测后存放图片路径
    out_dir = args.out_dir
 
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    
    # 检测阈值
    score_thr = args.score_thr
    
    img_list = []
    count = 0
    path = Path(img_dir)
        
    for p in path.iterdir():
        # print('model is processing the {}/{} images.'.format(count, len(img_list)))
        model = init_detector(config_file, checkpoint_file, device='cuda:0')
        result = inference_detector(model, str(p))
        img = show_result_pyplot(model, str(p), result, score_thr=score_thr)
        cv2.imwrite("{}/{}.jpg".format(out_dir, p.stem), img)    
 

if __name__ == '__main__':
    main()

演示如何推理:

python batch_infer.py \
--config work_dirs/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco.py \
--checkpoint-file work_dirs/yolox_s_8x8_300e_coco/bast.pth \
--img-dir data/coco/test2000 --out-dir work_dirs/detect/xs/test2000

注意事项

我使用的是mmdeteciton-2.27.0版本,2.x版本应该是通用的;还有就是在推理时是需要用到GPU的,我测试过,如果没有GPU会报错,所以请注意这两点。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐