Onnx模型量化

说明:

首先需要使用torch.onnx.export的方式把pt文件转成对应的onnx文件,这一步不多说。现在有了float32的onnx,加载到内存可能很大,在不考虑精度太大影响但对硬件资源有限的条件下,比如服务器内存很紧张,这些情况可以考虑onnx直接量化成int8。完整代码我将会在最后给出,只需要修改对应输入数据的取值范围和形状,onnx路径也需要修改成你的对应路径即可。
需要特别注意的地方:

  1. input_data_tensor 这个需要对应的你的工程的前处理
  2. “ActivationSymmetric”: False这个值我用的relu函数如果这里设置True会对精度影响很大,一般如果你的激活函数是对称函数可以设置True。
  3. 量化后结果:精度有影响需考虑能否接受,内存占用减小了至少1半,但是在cpu上运行速度完全没减少,没有对应gpu显卡加速不清楚其他人的情况。

代码:

import numpy as np
import onnxruntime
import torch

from onnxruntime import quantization

class QuntizationDataReader(quantization.CalibrationDataReader):
    def __init__(self, torch_ds, batch_size, input_name):

        self.torch_dl = torch.utils.data.DataLoader(torch_ds, batch_size=batch_size, shuffle=True)

        self.input_name = input_name
        self.datasize = len(self.torch_dl)

        self.enum_data = iter(self.torch_dl)

    def to_numpy(self, pt_tensor):
        return pt_tensor.detach().cpu().numpy() if pt_tensor.requires_grad else pt_tensor.cpu().numpy()

    def get_next(self):
        batch = next(self.enum_data, None)
        if batch is not None:
            return {self.input_name: self.to_numpy(batch)}
        else:
            return None

    def rewind(self):
        self.enum_data = iter(self.torch_dl)


if __name__ == '__main__':
    model_fp32_path = './_fp32.onnx'
    model_int8_path = './_int8.onnx'
    input_data = np.random.randint(0, 101, size=(32, 3, 160, 160))
    input_data_tensor = torch.tensor(input_data / 255.0, dtype=torch.float32)  # normalize()

    session = onnxruntime.InferenceSession(model_fp32_path)
    input_name = session.get_inputs()[0].name
    qdr = QuntizationDataReader(input_data_tensor, batch_size=8, input_name=input_name)
    q_static_opts = {"ActivationSymmetric": False,
                     "WeightSymmetric": True}
    # q_static_opts = {"ActivationSymmetric": True,
    #                  "WeightSymmetric": True}
    quantized_model = quantization.quantize_static(model_input=model_fp32_path,
                                                   model_output=model_int8_path,
                                                   calibration_data_reader=qdr,
                                                   extra_options=q_static_opts)



Logo

一起探索未来云端世界的核心,云原生技术专区带您领略创新、高效和可扩展的云计算解决方案,引领您在数字化时代的成功之路。

更多推荐