将神经网络训练好之后,如何保存它呢,保存它之后有如何提取它呢?

如下图所示,net1是训练好的神经网络,有两种方式保存它:1.保存整个训练好的神经网络,2.保存神经网络的最终参数

net2是根据第1种方式保存的。net2是根据第2种方式保存的

源代码:

# 引入模块
import torch
import torch.nn.functional as f
from torch.autograd import Variable
import matplotlib.pyplot as plt


# 生成一些假数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # 神经网络只能接受二维数据的输入
y = pow(x, 2) + 0.2*torch.rand(x.size())  # 后半部分制造噪音
x, y = Variable(x), Variable(y)  # 训练神经网络时只能接受Variable形式输入


# 定义保存函数
def save():
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
)
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    for i in range(100):
        prediction = net1(x)  # 喂数据x给net1
        loss = loss_func(prediction, y)
        optimizer.zero_grad()  # 将上面运算过程中的grad清零
        loss.backward()  # 误差反向传递
        optimizer.step()  # 将新参数作用于神经网络
    
    # 绘图
    plt.figure(figsize=(10, 3))  # 设置图像的大小
    plt.subplot(131)
    plt.title('net1', color='red', size=20)
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})
    
    # 保存net1的两种方式
    torch.save(net1, 'net1.pkl')  # 方式1:保存整个神经网络
    torch.save(net1.state_dict(), 'net1_parameters.pkl')  # 方式2:保存神经网络的参数      


# 定义提取整个神经网络的函数
def restore_net():
    net2 = torch.load('net1.pkl')  # 加载文件net1.pkl, 将其内容赋值给net2
    prediction = net2(x)
    loss_func = torch.nn.MSELoss()
    loss = loss_func(prediction, y)
    
    # 绘制net2结果图形
    plt.subplot(132)
    plt.title('net2', color='red', size=20)
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})


# 定义提取神经网络状态参数的函数
def restore_net_parameters():
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
)  # 构造net3的基本框架
    net3.load_state_dict(torch.load('net1_parameters.pkl'))  # 提取net1的状态参数,将状态参数给net3
    prediction = net3(x)
    loss_func = torch.nn.MSELoss()
    loss = loss_func(prediction, y)
    
    # 绘制net3结果图形
    plt.subplot(133)
    plt.title('net3', color='red', size=20)
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})


# 调用函数
save()
restore_net()
restore_net_parameters()
plt.show()  # 将三个函数绘制的图形显示出来

注意:将plt.show()放置在最后,能显示出三幅图像连在一起的。若在每个定义的函数的后面均加上plt.show(),三幅图像是分开显示的,无法连成一个整体。

Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐