本文源码基于开源项目https://github.com/bianjingshan/MOT-deepsort.git

内容简介:

        目前主流的目标跟踪算法都是基于Tracking by Detection策略,即根据目标检测结果进行目标跟踪,本文主要对经典的目标跟踪算法DeepSort C++版本代码的工作流程以及模块进行讲解。如果本文内容有误,欢迎指出探讨。

        本次代码讲解主要遵循由上而下的方式,先讲总体后讲细节。

 代码目录结构如下:

 

data/2DMOT2015/test:用于测试的目标跟踪数据集

demo:目标跟踪简单应用demo

src:目标跟踪源码

 先上代码deersort.h代码:

#ifndef _DEEP_SORT_H_
#define _DEEP_SORT_H_

#include <vector>
//#ifdef __cplusplus
//extern "C" {
//#endif

//坐标信息
typedef struct
{
	int x;
	int y;
	int width;
	int height;
}DS_Rect;

//目标检测信息
typedef struct
{
	int class_id;
	DS_Rect rect;
	float confidence;
}DS_DetectObject;

//目标跟踪信息
typedef struct
{
	int track_id;
	int class_id;
	float confidence;
	DS_Rect rect;
}DS_TrackObject;


typedef void * DS_Tracker;
typedef std::vector<DS_DetectObject> DS_DetectObjects;
typedef std::vector<DS_TrackObject> DS_TrackObjects;

//目标跟踪器的创建
DS_Tracker DS_Create(
	float max_cosine_distance=0.2, 
	int nn_budget=100, 
    float max_iou_distance = 0.7, 
    int max_age = 30, 
	int n_init=3);
//跟踪器删除
bool DS_Delete(DS_Tracker h_tracker);

//跟踪器的更新以完成跟踪过程
bool DS_Update(
	DS_Tracker h_tracker, 
	DS_DetectObjects detect_objects, 
	DS_TrackObjects &track_objects);


//#ifdef __cplusplus
//}
//#endif

#endif


        首先定义三个结构体,分别为目标坐标信息、目标检测信息以及目标跟踪信息结构体;函数部分主要为跟踪器的创建、删除、更新。

上deepsort.cpp源码:

#include <iostream>
#include <sstream>
#include <stdio.h>
#include <fstream>
#include <stdlib.h>
#include <unistd.h>
#include "tracker.h"
#include "deepsort.h"

//跟踪器创建
DS_Tracker DS_Create(
	float max_cosine_distance, 
	int nn_budget, 
    float max_iou_distance, 
    int max_age, 
	int n_init)
{
    return (DS_Tracker)(new tracker(max_cosine_distance, nn_budget, max_iou_distance, max_age, n_init));
}

//跟踪器删除
bool DS_Delete(DS_Tracker h_tracker)
{
    delete((tracker *)h_tracker);
    return true;
}
#if 0
bool DS_Update(
    DS_Tracker h_tracker, 
	DS_DetectObject *p_detects, 
	int detect_num, 
	DS_TrackObject *p_tracks, 
	int *p_tracks_num, 
	int max_tracks_num)
{
    tracker *p_tracker=(tracker *)h_tracker;
    DETECTION_ROW temp_object;
    DETECTIONS detections;
    for(int iloop=0;iloop<detect_num;iloop++)
    {
        temp_object.confidence=p_detects[iloop].confidence;
        temp_object.tlwh = DETECTBOX(p_detects[iloop].x, p_detects[iloop].y, p_detects[iloop].width, p_detects[iloop].height);
//如果使用特征匹配,则先清空特征数组
#ifdef FEATURE_MATCH_EN
        temp_object.feature.setZero();
#endif 
        detections.push_back(temp_object);
    }
    p_tracker->predict();
	p_tracker->update(detections);
    DETECTBOX output_box;
    int output_num=0;
    for(Track& track : p_tracker->tracks) 
    {
        if(!track.is_confirmed() || track.time_since_update > 1) continue;
        output_box=track.to_tlwh();
        
        p_tracks[output_num].track_id=track.track_id;
        p_tracks[output_num].x=output_box(0);
        p_tracks[output_num].y=output_box(1);
        p_tracks[output_num].width=output_box(2);
        p_tracks[output_num].height=output_box(3);
        output_num++;
        if(output_num>=max_tracks_num)
        {
            break;
        }
    }
    *p_tracks_num=output_num;
    return true;
}
#endif

//跟踪器进行更新跟踪
bool DS_Update(
	DS_Tracker h_tracker, 
	DS_DetectObjects detect_objects, 
	DS_TrackObjects &track_objects)
{
    tracker *p_tracker=(tracker *)h_tracker;
    DETECTION_ROW temp_object;
    DETECTIONS detections;
    //遍历目标检测信息
    for(int iloop=0;iloop<detect_objects.size();iloop++)
    {
        temp_object.class_id=detect_objects[iloop].class_id;
        temp_object.confidence=detect_objects[iloop].confidence;
        temp_object.tlwh = DETECTBOX(
            detect_objects[iloop].rect.x, 
            detect_objects[iloop].rect.y, 
            detect_objects[iloop].rect.width, 
            detect_objects[iloop].rect.height);
#ifdef FEATURE_MATCH_EN
        temp_object.feature.setZero();
#endif 
        detections.push_back(temp_object);
    }
    //tracker进行推理更新
    p_tracker->predict();
	p_tracker->update(detections);
    DETECTBOX output_box;
    DS_TrackObject track_object;
    track_objects.clear();
    //筛选出跟踪匹配的目标
    for(Track& track : p_tracker->tracks) 
    {
        if(!track.is_confirmed() || track.time_since_update > 1) continue;
        output_box=track.to_tlwh();
        
        track_object.track_id=track.track_id;
        track_object.class_id=track.class_id;
        track_object.confidence=track.confidence;
        track_object.rect.x=output_box(0);
        track_object.rect.y=output_box(1);
        track_object.rect.width=output_box(2);
        track_object.rect.height=output_box(3);
        track_objects.push_back(track_object);
    }
    return true;
}

        这里的DS_Create为初始化一些参数并创建一个tracker返回;DS_Delete则删除创建的tracker;

        DS_Update实际使用以下面的为例,三个参数分别为初始化好的tracker,目标检测信息数组,目标跟踪信息数组的引用。

        首先将目标检测到的类别、坐标等信息遍历放入到vector,然后使用tracker进行预测推理更新,更新后将得到的tracks进行筛选,得出本次跟踪的结果。其中track_id为跟踪id,可以在外部打印或者画到图像中以查看每个跟踪的目标id。

         下一篇将对tracker的内容做讲解。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐