Pytorch 学习(九):Pytorch 数据和模型存取
Pytorch 网络模型创建本方法总结自《动手学深度学习》(Pytorch版)github项目Pytorch 存储和读取主要依靠 load 和 save 函数模型存取依靠 load_state_dict() 函数数据存储与读取import torchpath = 'p.pth'# 'p.pt'a = torch.tensor(1)torch.save(a, path)b = torch.load(
·
Pytorch 数据和模型存取
本方法总结自《动手学深度学习》(Pytorch版)github项目
- Pytorch 存储和读取主要依靠 load 和 save 函数
- 模型存取依靠 load_state_dict() 函数
数据存储与读取
import torch
path = 'p.pth' # 'p.pt'
a = torch.tensor(1)
torch.save(a, path)
b = torch.load(path)
模型存取
- 仅存储/加载模型参数
model = net()
state_dict = model.state_dict() # 模型状态
torch.save(state_dict, path)
model2 = net()
model2.load_state_dict(torch.load(path))
- 存储/加载整个模型
model = net()
torch.save(model, path)
model2 = torch.load(path)
更多推荐
已为社区贡献1条内容
所有评论(0)