作为科研领域十分重要的计算工具,MATLAB在深度学习方面也一直与时俱进,每一个版本的更新都会引进许多新的机器学习和深度学习案例。

下面介绍将训练好的网络进行保存的方法,当再次调用网络时,可以在前一次训练的基础上进一步训练或者直接处理新数据,从而节省时间,提高效率。也可以将网络直接用于新数据的学习和处理,而不需要重新训练数据。

在python深度学习中,也存在“断点续训”这种类似的操作,将训练得到的网络参数以txt文本的形式保存下来,并在之后的运行中通过导入参数实现网络的继续训练。DL with python(9)——TensorFlow实现神经网络模型的断点续训对此进行了简单的介绍和实现。

MATLAB中实现网络模型的保存和调用十分简单,具体实现可以参考以下代码中的两句代码

save('net.mat','net');       % 将网络net保存为.mat文件,后面可直接调用
load('net.mat');     % 导入之前保存的网络

其中,save函数将保存一个net.mat文件在当前工作文件夹,其中是经过训练的网络;load函数导入该文件,并在后面的代码中直接调用网络net。

完整的程序如下,实现的是RBF神经网络三分类问题。

%% 网络的构建和训练
% 训练数据,输入为9*3的矩阵,9个输入,带有3个特征
data = [10 0 0;
        10 0 1;
        10 1 0;
        2 10 0;
        2 10 1;
        2 11 0;
        5 0 10;
        5 0 11;
        5 1 10];
% 从输出目标可以看到输入分为3类
target = [1;1;1;2;2;2;3;3;3];

% 利用数据构建RBF神经网络并训练
net = newrb(data',target');  % 注意矩阵的转置
save('net.mat','net');       % 将网络net保存为.mat文件,后面可直接调用

% 查看效果
y = sim(net,data');  % 网络对输入进行运算得到输出y
y=round(y);          % 将输出y的近似值作为分类结果
performance = sum(target==y')/size(target,1)  % 计算网络输出和实际输出的对应程度

%% 测试训练后的模型
load('net.mat');     % 导入之前保存的网络
testdata = [10 0 0]; % 给出一个新的数据
y = sim(net,testdata'); % 利用训练后的网络对新数据进行分类
y=round(y);             % 得到分类结果

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐