ONNX介绍

ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的深度学习框架(如Pytorch, MXNet)可以采用相同格式存储模型数据。简而言之,ONNX是一种便于在各个主流深度学习框架中迁移模型的中间表达格式。

ONNX与Protobuf

ONNX采用序列化数据结构协议protobuf来存储模型信息。我们可以通过protobuf 自己设计一种数据结构的协议,然后是用各种语言去读取或者写入。ONNX中采用Onnx.proto定义ONNX的数据协议规则和一些其他的信息。同样,也可以借助protobuf来解析Onnx模型。

ONNX数据结构

ONNX中主要定义了以下六种数据结构:
Onnx数据结构

ONNX模型解析

ONNX模型解析流程:

  1. 读取.onnx文件,获得model结构;
  2. 通过model结构访问graph结构;
  3. 通过graph访问整个网络的所有node以及inputs、outputs;
  4. 通过node结构,可以获取每一个OP的参数信息。

其中,graph结构中还定义了initializer和value_info,分别存放了模型的权重参数和每个节点的输出信息。

ONNX模型构建及推理

ONNX模型构建流程 :

  1. 根据网络结构make_node创建相关节点,节点的inputs和outputs参数决定了graph的连接情况;
  2. 利用定义好的节点,make_graph生成计算图;
  3. 利用graph make_model;
  4. Check_model、Save_model。

模型构建完毕后,利用Onnxruntime推理模型。

以下是一个只有一层Maxpool的Onnx模型构建和推理过程:

"""
create on 2021-4-7 22:48
Maxpool layer test in Onnx
@author: yang
"""

import onnx
import numpy as np
from onnx import helper
from onnx import TensorProto
import onnxruntime as rt

def network_construct():

    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 3, 3])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 2, 2])

    # Make MaxPool Node
    node_def = onnx.helper.make_node(
        'MaxPool',
        inputs=['X'],
        outputs=['Y'],
        kernel_shape=[2, 2],
        strides=[2, 2],
        pads=[1, 1, 1, 1]   # Top、Left、Bottom、Right
    )

    # Make Graph
    graph_def = helper.make_graph(
        [node_def],
        'test-MaxPool',
        [X],
        [Y]
    )

    # Make model
    model_def = helper.make_model(
        graph_def,
        producer_name='yang'
    )

    # Check & Save Model
    onnx.checker.check_model(model_def)
    onnx.save(model_def, 'MaxPool.onnx')

def  model_infer():

    # Infer Model
    sess = rt.InferenceSession('MaxPool.onnx')

    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    input_data = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]
    input_data = np.array(input_data, dtype=np.float32)

    result = sess.run([output_name], {input_name: input_data})
    print(result)

def main():
    network_construct()
    model_infer()

if __name__ == '__main__':
    main()

ONNX模型修改

在实际使用过程中,笔者需要对Onnx模型进行拆分、删除或者修改某个特定节点。下面记录了笔者对模型节点拆分和节点修改的过程,使用需结合模型具体情况。

节点删除:

"""
create on 2021-4-7 22:48
To Del some unwanted nodes in the OnnxModel
@author: yang
"""

import onnx
import onnx.helper as helper

def Del_node():
    model_path = "Phase4_85_135.onnx"
    model = onnx.load_model(model_path)
    onnx.checker.check_model(model_path)

    # Del Nodes List
    node_name = ["Constant", "Shape", "Gather", "GlobalAveragePool", "Unsqueeze", "Concat", "Reshape", "Gemm", "Sigmoid", "Mul"]
    node_output = ['581', '595']

    graph = model.graph
    print("Node num:", len(graph.node))

    for node in graph.node:
        if str(node.output)[2:5] in node_output:
            print(str(node.output)[2:5])
            model.graph.node.remove(node)

    for i in range(len(graph.node)-1, -1, -1):
        if graph.node[i].op_type in node_name:
            graph.node.remove(graph.node[i])

    # New Output
    model.graph.output.pop(0)
    model.graph.output.append(helper.make_tensor_value_info("568", 1, (1, 2048, 8, 6)))
    model.graph.output[0].name = "568"

    model.graph.output.append(helper.make_tensor_value_info("594", 1, (1, 2048, 8, 6)))
    model.graph.output[1].name = "594"

    print("After del Node num:", len(model.graph.node))
    onnx.save_model(model, "del_model.onnx")

def main():
    Del_Nodes()

if __name__ == '__main__':
    main()

节点修改:

"""
create on 2021-4-7 22:48
Modify the ONet OnnxModel in MTCNN
Change dropout layer to reshape layer
@author: yang
"""

import onnx
import numpy as np
from onnx import helper
from onnx import TensorProto
from onnx import shape_inference

def modify_model():
    model = onnx.load_model('../ONet/ONet.onnx')
    model.producer_name = "yang"

    graph = model.graph
    node = graph.node
    initializer = graph.initializer

    # Get the dropout node index
    index = 0
    for each in range(len(node)):
        if node[each].name == 'drop5':
            index = each

    # Make reshape input tensor
    ShapeTensor = helper.make_tensor(
        'ShapeTensor',
        TensorProto.INT64,
        [4],
        [1, 256, 1, 1]
    )

    # Make reshape node
    reshape_node = helper.make_node(
        'Reshape',
        inputs=['conv5_Gemm_Y', 'ShapeTensor'],
        outputs=['prelu5_reshape_Y'],
        name='prelu5_Reshape',
    )

    # Modify next node input
    node[index+1].input[0] = 'prelu5_reshape_Y'

    # Remove & insert node
    graph.node.remove(node[index])
    graph.node.insert(index, reshape_node)

    # Append initializer
    initializer.append(ShapeTensor)

    # Make graph
    graph = helper.make_graph(node, graph.name, graph.input, graph.output, initializer)

    # Make model
    info_model = helper.make_model(graph)

    # Add valueinfo
    model = shape_inference.infer_shapes(info_model)
    
    onnx.checker.check_model(model)
    onnx.save_model(model, 'ONet_m.onnx')

def main():
    modify_model()

if __name__ == '__main__':
    main()
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐