torch转onnx

import torch
from models import Net
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pth_path='epoch-best.pth' 
model_demo = Net()  # 构造模型,创建新模型,网络实例
loaded_model = torch.load(pth_path)  # 加载模型参数
model_demo.load_state_dict(loaded_model['state_dict'])  #将模型参数加载到构造的新建模型实例model_demo中,需要创建的model_demo模型和加载模型的结构、参数名称、参数维度相同,不同时,可选择加载相同部分参数
model_demo = model_demo.to(device)
model_demo.eval()

dummy_input = torch.randn(1, 3, 320, 320,device=device)  # input data  
torch.onnx.export(model_demo, dummy_input, 'epoch-best.onnx', verbose=False, input_names=input_names, output_names=output_names)   # torch2onnx
Logo

欢迎加入我们的广州开发者社区,与优秀的开发者共同成长!

更多推荐