YOLOv8剪枝

前言

手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。

本次课程主要讲解YOLOv8剪枝。

课程大纲可看下面的思维导图

在这里插入图片描述

1.Overview

YOLOV8剪枝的流程如下:

在这里插入图片描述

结论:在VOC2007上使用yolov8s模型进行的实验显示,预训练和约束训练在迭代50个epoch后达到了相同的mAP(:0.5)值,约为0.77。剪枝后,微调阶段需要65个epoch才能达到相同的mAP50。修建后的ONNX模型大小从43M减少到36M。

注意:我们需要将网络结构和网络权重区分开来,YOLOv8的网络结构来自yaml文件,如果我们进行剪枝后保存的权重文件的结构其实是和原始的yaml文件不符合的,需要对yaml文件进行修改满足我们的要求。

2.Pretrain(option)

步骤如下:

  • git clone https://github.com/ultralytics/ultralytics.git
  • use VOC2007, and modify the VOC.yaml(去除VOC2012的相关内容)
  • disable amp(禁用amp混合精度)
# FILE: ultralytics/yolo/engine/trainer.py
...
def check_amp(model):
    # Avoid using mixed precision to affect finetune
    return False # <============== modified(修改部分)
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices

    def amp_allclose(m, im):
        # All close FP32 vs AMP results
    ...

3.Constrained Training

约束训练是为了筛选哪些channel比较重要,哪些channel没有那么重要,也就是我们上节课所说的稀疏训练

  • prune the BN layer by adding L1 regularizer.
# FILE: ultralytics/yolo/engine/trainer.py
...
# Backward
self.scaler.scale(self.loss).backward()

# <============ added(新增)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni
...

注意1:在剪枝时,我们选择加载last.pt而非best.pt,因为由于迁移学习,模型的泛化性比较好,在第一个epoch时mAP值最大,但这并不是真实的,我们需要稳定下来的一个模型进行prune

注意2:我们在对Conv层进行剪枝时,我们只考虑1v1(如BottleNeck,C2f and SPPF)、1vm(如Backbone,Detect)的情形,并不考虑mv1的情形。

思考:Constrained Training的必要性?

约束训练可以使得模型更易于剪枝。在约束训练中,模型会学习到一些通道或者权重系数比较不重要的信息,而这些信息在剪枝过程中得到应用,从而达到模型压缩的效果。而如果直接进行剪枝操作,可能会出现一些问题,比如剪枝后的模型精度大幅下降、剪枝不均匀等。因此,在进行剪枝操作之前,通过稀疏训练的方式,可以更好地准确地确定哪些通道或者权重系数可以被剪掉,从而避免上述问题的发生。

4.Prune

4.1 检查BN层的bias

  • 剪枝后,确保BN层的大部分bias足够小(接近于0),否则重新进行稀疏训练
for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

4.2 设置阈值和剪枝率

  • threshold:全局或局部
  • factor:保持率,裁剪太多不推荐
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)

4.3 最小剪枝Conv单元的TopConv

Top-Bottom Conv如下图所示:

在这里插入图片描述

TopConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta  = conv1.bn.bias.data.detach()

    keep_idxs = []    
    local_threshold = threshold

    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0] 
        local_threshold = local_threshold * 0.5

    n = len(keep_idxs)
    print(n / len(gamma) * 100)  # 打印我们保留了多少的channel
    
    # prune
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data   = beta[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.running_var.data  = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.num_features   = n
    conv1.conv.weight.data  = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n

    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

# pattern to prune
# 1. prune all 1 vs 1 TB pattern e.g. bottleneck
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)

注意:由于NVIDIA的硬件加速的原因,我们保留的channels应该大于等于8,我们可以通过设置local_threshold,尽量小点,让更多的channel保留下来。

4.4 最小剪枝Conv单元的BottomConv

BottomConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):
    ...
    if not isinstance(conv2, list):
        conv2 = [conv2]
    
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

注意BottomConv存在两种情形,一种是单个Conv,还有一种是Conv列表。TopConv需要考虑conv2d+bn+act,而BottomConv只需要考虑conv2d

4.5 Seq剪枝

Seq剪枝的示例代码如下:

def prune(m1, m2):
    if isinstance(m1, C2f):
        m1 = m1.cv2
    
    if not isinstance(m2, list):
        m2 = [m2]
    
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    
    prune_conv(m1, m2)

# 2. prune sequential
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i+1])

注意:我们不考虑1vm的情形,因此在4,6,9的module我们是不进行剪枝的,后续head进行Concat时是对4,6,9的module进行拼接的。考虑到前几个conv的特征提取的重要性,我们也不剪枝它们。(那感觉没剪几个module呀😂)

4.6 Detect-FPN剪枝

Detect-FPN剪枝的示例代码如下:

# 3. prune FPN related block
detect: Detect = seq[-1]

last_inputs = [seq[15], seq[18], seq[21]]
colasts     = [seq[16], seq[19], None]

for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

for name, p in yolo.model.named_parameters():
    p.requires_grad = True

注意:一定要设置所有参数为需要训练。因为加载后的model会给弄成False,导致报错

4.7 完整示例代码

完整的示例代码如下:

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect

# Load a model
yolo = YOLO("epoch-50-constrained_weights/last.pt")  # build a new model from scratch
model = yolo.model

ws = []
bs = []

for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
# keep
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)


def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta  = conv1.bn.bias.data.detach()
    # if gamma.abs().min() > f or beta.abs().min() > 0.1:
    #     return
    
    # idxs = torch.argsort(gamma.abs() * coeff + beta.abs(), descending=True)
    keep_idxs = []
    local_threshold = threshold
    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data   = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n
    
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    if not isinstance(conv2, list):
        conv2 = [conv2]
        
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
    
def prune(m1, m2):
    if isinstance(m1, C2f):      # C2f as a top conv
        m1 = m1.cv2
    
    if not isinstance(m2, list): # m2 is just one module
        m2 = [m2]
        
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    
    prune_conv(m1, m2)

for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)
        
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i+1])
    
detect:Detect = seq[-1]
last_inputs   = [seq[15], seq[18], seq[21]]
colasts       = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

# ***step4,一定要设置所有参数为需要训练。因为加载后的model他会给弄成false。导致报错
# pipeline:
# 1. 为模型的BN增加L1约束,lambda用1e-2左右
# 2. 剪枝模型,比如用全局阈值
# 3. finetune,一定要注意,此时需要去掉L1约束。最终final的版本一定是去掉的
for name, p in yolo.model.named_parameters():
    p.requires_grad = True
    
# 1. 不能剪枝的layer,其实可以不用约束
# 2. 对于低于全局阈值的,可以删掉整个module
# 3. keep channels,对于保留的channels,他应该能整除n才是最合适的,否则硬件加速比较差
#    n怎么选,一般fp16时,n为8
#                int8时,n为16
#    cp.async.cg.shared
#

yolo.val()
# yolo.export(format="onnx")
# yolo.train(data="VOC.yaml", epochs=100)
print("done")

5.YOLOv8剪枝总结

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

总结

本次课程学习了YOLOv8的剪枝,主要是对前面剪枝课程的一个总结和实现吧,大体流程就是稀疏训练后进行剪枝最后微调,看着虽然简单,但实际细节把控还是非常多的,比如说哪些部分好剪,哪些部分不好剪,剪枝的过程中如何通过model获取想要prune的module等等,需要对YOLOv8整体网络结构和对ONNX模型的操作非常熟练,这还只是基础理论,实操部分的坑还没踩呢,在之后好好练习练习吧😄

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐