Pytorch 数据和模型存取

本方法总结自《动手学深度学习》(Pytorch版)github项目

  • Pytorch 存储和读取主要依靠 load 和 save 函数
  • 模型存取依靠 load_state_dict() 函数
数据存储与读取
import torch

path = 'p.pth'  # 'p.pt'
a = torch.tensor(1)
torch.save(a, path)
b = torch.load(path)
模型存取
  • 仅存储/加载模型参数
model = net()
state_dict = model.state_dict()  # 模型状态
torch.save(state_dict, path)
model2 = net()
model2.load_state_dict(torch.load(path))
  • 存储/加载整个模型
model = net()
torch.save(model, path)
model2 = torch.load(path)
Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐