前言

在MMDetection框架中,我们经常会在forward函数中看到下面的代码,以ATSS为例。

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction
                cls_scores (list[Tensor]): Classification scores for all scale
                    levels, each is a 4D-tensor, the channels number is
                    num_anchors * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for all scale
                    levels, each is a 4D-tensor, the channels number is
                    num_anchors * 4.
        """
        return multi_apply(self.forward_single, x, self.scales)

forward函数的作用是得到网络的输出。对于检测任务来说通常是三个输出,分别对应分类分支和回归分支以及置信度分支。

上面函数中X是经过NECK后的向量,经过multi_apply后ATSS返回的是classification scoresbbox prediction以及centerness

流程

在正式讲解之前,先看一下函数调用流程

forward
multi_apply
forward_single

通过上图我们可以清晰的看到multi_apply调用了forward_single函数,从而得到我们想要的输出。那multi_apply是怎么调用的呢?

前置知识

  1. python中的partial作用
  2. python中的map函数作用
  3. python中的zip函数作用

调用过程解析

在看本章节的时候,务必确保自己懂了上面的前置知识。

def multi_apply(func, *args, **kwargs):
    """Apply function to a list of arguments.

    Note:
        This function applies the ``func`` to multiple inputs and
        map the multiple outputs of the ``func`` into different
        list. Each list contains the same type of outputs corresponding
        to different inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments

    Returns:
        tuple(list): A tuple containing multiple list, each list contains \
            a kind of returned results by the function
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))
函数的作用
pfunc = partial(func, **kwargs) if kwargs else func
# func对象是我们传过去的函数,即forward_single,如下图

在这里插入图片描述

map_results = map(pfunc, *args)
# *args 就是我们前面传过来的x和scale
# 本句的作用,调用forward_single得到网络的输出。
return tuple(map(list, zip(*map_results)))
# 将网络的输出按组打包。
# 原始的fpn某一层网络输出是(cls,reg,obj),经过zip之后,五层fpn的输出变为了([cls1, cls2, cls3, cls4, cls5], [reg1, reg2, reg3, reg4, reg5], [obj1, obj2, obj3, obj4, obj5])如下图

在这里插入图片描述
以上就是对MMDetection中的multi_apply的理解,如有疑问欢迎交流。

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐