pytorch(四):神经网络的保存与提取
将神经网络训练好之后,如何保存它呢,保存它之后有如何提取它呢?如下图所示,net1是训练好的神经网络,有两种方式保存它:1.保存整个训练好的神经网络,2.保存神经网络的最终参数net2是根据第1种方式保存的。net2是根据第2种方式保存的源代码:# 引入模块import torchimport torch.nn.functional as ffrom torch.aut...
·
将神经网络训练好之后,如何保存它呢,保存它之后有如何提取它呢?
如下图所示,net1是训练好的神经网络,有两种方式保存它:1.保存整个训练好的神经网络,2.保存神经网络的最终参数
net2是根据第1种方式保存的。net2是根据第2种方式保存的
源代码:
# 引入模块
import torch
import torch.nn.functional as f
from torch.autograd import Variable
import matplotlib.pyplot as plt
# 生成一些假数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 神经网络只能接受二维数据的输入
y = pow(x, 2) + 0.2*torch.rand(x.size()) # 后半部分制造噪音
x, y = Variable(x), Variable(y) # 训练神经网络时只能接受Variable形式输入
# 定义保存函数
def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for i in range(100):
prediction = net1(x) # 喂数据x给net1
loss = loss_func(prediction, y)
optimizer.zero_grad() # 将上面运算过程中的grad清零
loss.backward() # 误差反向传递
optimizer.step() # 将新参数作用于神经网络
# 绘图
plt.figure(figsize=(10, 3)) # 设置图像的大小
plt.subplot(131)
plt.title('net1', color='red', size=20)
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})
# 保存net1的两种方式
torch.save(net1, 'net1.pkl') # 方式1:保存整个神经网络
torch.save(net1.state_dict(), 'net1_parameters.pkl') # 方式2:保存神经网络的参数
# 定义提取整个神经网络的函数
def restore_net():
net2 = torch.load('net1.pkl') # 加载文件net1.pkl, 将其内容赋值给net2
prediction = net2(x)
loss_func = torch.nn.MSELoss()
loss = loss_func(prediction, y)
# 绘制net2结果图形
plt.subplot(132)
plt.title('net2', color='red', size=20)
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})
# 定义提取神经网络状态参数的函数
def restore_net_parameters():
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
) # 构造net3的基本框架
net3.load_state_dict(torch.load('net1_parameters.pkl')) # 提取net1的状态参数,将状态参数给net3
prediction = net3(x)
loss_func = torch.nn.MSELoss()
loss = loss_func(prediction, y)
# 绘制net3结果图形
plt.subplot(133)
plt.title('net3', color='red', size=20)
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})
# 调用函数
save()
restore_net()
restore_net_parameters()
plt.show() # 将三个函数绘制的图形显示出来
注意:将plt.show()放置在最后,能显示出三幅图像连在一起的。若在每个定义的函数的后面均加上plt.show(),三幅图像是分开显示的,无法连成一个整体。
更多推荐
已为社区贡献1条内容
所有评论(0)