trainNetwork - Matlab官网介绍的中文版
trainNetwork训练神经网络进行深度学习原地址 https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html几种调用方法net = trainNetwork(imds,layers,options)net = trainNetwork(ds,layers,options)net = trainNetwork(X,Y,la
trainNetwork训练神经网络进行深度学习
原地址 https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html
几种调用方法
描述
使用trainNetwork
训练卷积神经网络(ConvNet,CNN),长短期记忆(LSTM)网络,或双向LSTM(BiLSTM)网络的深度学习分类和回归的问题。您可以在CPU或GPU上训练网络。对于图像分类和图像回归,您可以使用多个GPU或并行进行训练。使用GPU,多GPU和并行选项需要Parallel Computing Toolbox™。要使用深层学习GPU,你还必须有一个CUDA ®启用NVIDIA ® GPU计算能力3.0或更高版本。使用指定培训选项,包括用于执行环境的选项trainingOptions
。
为分类和回归问题训练网络。预测变量必须位于的第一列中net
= trainNetwork(tbl
,responseName
,layers
,options
)tbl
。该responseName
参数指定在响应变量tbl
。
例子
- 图像分类训练网络
将数据作为ImageDatastore
对象加载。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet',...
'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath,...
'IncludeSubfolders',true,...
'LabelSource','foldernames');
数据存储区包含10,000个从0到9的数字合成图像。这些图像是通过对使用不同字体创建的数字图像应用随机转换而生成的。每个数字图像为28 x 28像素。数据存储区每个类别包含相等数量的图像。
显示数据存储中的某些图像。
figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
subplot(4,5,i);
imshow(imds.Files{perm(i)});
end
指定卷积神经网络架构。对于回归问题,请在网络末端包括一个回归层。
layers = [ ...
imageInputLayer([28 28 1])
convolution2dLayer(5,20)
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
指定网络训练选项。将初始学习速率设置为0.001。
options = trainingOptions('sgdm',...
'InitialLearnRate',0.001,...
'Verbose',false,...
'Plots','training-progress');
训练网络。
net = trainNetwork(imdsTrain,layers,options);
通过评估测试数据的预测准确性来测试网络的性能。使用predict
预测验证图像的旋转角度。
[XTest,〜,YTest] = digitTest4DArrayData;
YPred =predict(net,XTest);
通过计算预测旋转角和实际旋转角的均方根误差(RMSE)来评估模型的性能。
rmse = sqrt(mean((YTest-YPred)。^ 2))
rmse = single
6.0655
序列分类训练网络
查看MATLAB命令
训练用于序列到标签分类的深度学习LSTM网络。
如[1]和[2]中所述加载日语元音数据集。XTrain
是包含270个长度可变且特征尺寸为12的序列的单元格数组。Y
是标签1,2,...,9的分类向量。中的条目XTrain
是具有12行(每个要素一行)和不同列数(每个时间步长一列)的矩阵。
[XTrain,YTrain] = japaneseVowelsTrainData;
可视化图中的第一个时间序列。每行对应一个特征。
数字
情节(XTrain {1}')
标题(“训练观察1”)
numFeatures = size(XTrain {1},1);
图例(“ Feature” + string(1:numFeatures),'Location','northeastoutside')
定义LSTM网络体系结构。将输入大小指定为12(输入数据的特征数)。指定一个LSTM层,使其具有100个隐藏单元并输出序列的最后一个元素。最后,通过包括大小为9的完全连接的层,其后是softmax层和分类层,来指定九个类。
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;
层数= [ ...
sequenceInputLayer(inputSize)
lstmLayer(numHiddenUnits,'OutputMode','last')
fullyConnectedLayer(numClasses)
softmaxLayer
分类图层]
层数=
具有层的5x1层阵列:
1英寸序列输入序列输入具有12个尺寸
2英寸LSTM LSTM具有100个隐藏单元
3英寸全连接9个全连接层
4英寸Softmax softmax
5''分类输出交叉熵
指定训练选项。将求解器指定为'adam'
和'GradientThreshold'
。1.将小批量大小设置为27,并将最大纪元数设置为100。
由于小批量生产的序列短,因此CPU更适合训练。设置'ExecutionEnvironment'
到'cpu'
。要在GPU上进行训练(如果有),请设置'ExecutionEnvironment'
为'auto'
(默认值)。
maxEpochs = 100;
miniBatchSize = 27;
options = trainingOptions('adam',...
'ExecutionEnvironment','cpu',...
'MaxEpochs',maxEpochs,...
'MiniBatchSize',miniBatchSize,...
'GradientThreshold',1,...
'详细'',false,...
``情节'',``培训进度'');
使用指定的培训选项来培训LSTM网络。
net = trainNetwork(XTrain,YTrain,图层,选项);
加载测试集并将序列分类为扬声器。
[XTest,YTest] = japaneseVowelsTestData;
分类测试数据。指定用于训练的相同的小批量大小。
YPred = classify(net,XTest,'MiniBatchSize',miniBatchSize);
计算预测的分类准确性。
acc = sum(YPred == YTest)./ numel(YTest)
acc = 0.9541
输入参数
ds
— 数据存储数据
存储
数据存储,用于内存不足数据和预处理。
对于只有一个输入的网络,数据存储区返回的表或单元格数组有两列,分别指定了网络输入和期望的响应。
对于具有多个输入的网络,数据存储区必须是组合或转换后的数据存储区,该数据存储区将返回具有(numInputs
+1)列的单元格数组,其中包含预测变量和响应,其中 numInputs
是网络输入 numResponses
的数量,是响应的数量。对于i
小于或等于的值,单元阵列numInputs
的i
第th个元素对应于input layers.InputNames(i)
,其中 layers
是定义网络体系结构的层图。单元格数组的最后一列对应于响应。
下表列出了直接与兼容的数据存储 trainNetwork
。您可以使用transform
和combine
函数将其他内置数据存储区用于训练深度学习网络。这些函数可以将从数据存储中读取的数据转换为所需的表或单元格数组格式 trainNetwork
。有关更多信息,请参阅用于深度学习的数据存储。
数据存储类型 | 描述 |
---|---|
CombinedDatastore | 水平串联从两个或多个基础数据存储读取的数据。 |
TransformedDatastore | 根据您自己的预处理管道,转换来自底层数据存储的批量读取数据。 |
AugmentedImageDatastore | 应用随机仿射几何变换,包括调整大小,旋转,反射,剪切和平移,以训练深度神经网络。 |
PixelLabelImageDatastore | 将相同的仿射几何变换应用于图像和相应的地面真相标签,以训练语义分割网络(需要Computer Vision Toolbox™)。 |
RandomPatchExtractionDatastore | 从图像或像素标签图像中提取成对的随机色块(需要Image Processing Toolbox™)。您可以选择将相同的随机仿射几何变换应用于面片对。 |
DenoisingImageDatastore | 将随机生成的高斯噪声应用于训练降噪网络(需要“图像处理工具箱”)。 |
定制小批量数据存储 | 创建序列,时间序列或文本数据的迷你批。有关详细信息,请参阅开发自定义微型批处理数据存储。 |
sequences
— 数字数组的序列或时间序列数据
单元格数组 | 数值数组 | 数据存储
序列或时间序列数据,指定为N乘1的数字数组单元格数组,其中N是观察数,代表单个序列的数字数组或数据存储。
对于单元格数组或数字数组输入,包含序列的数字数组的维数取决于数据类型。
输入项 | 描述 |
---|---|
矢量序列 | Ç -by- 小号矩阵,其中 Ç是的序列的特征的数量和š是序列长度。 |
二维图像序列 | h- by- w - c- by- s 数组,其中h,w和 c分别对应于图像的高度,宽度和通道数,而s是序列长度。 |
3-D图像序列 | ħ -by- 瓦特 -by- d -by- Ç -by- 小号,其中ħ,瓦特, d,和Ç对应的高度,宽度,深度和3-d的图像,分别的通道数,和s是序列长度。 |
对于数据存储区输入,数据存储区必须以序列的单元格数组或第一列包含序列的表的形式返回数据。序列数据的尺寸必须与上表相对应。
Y
— 响应
标签的分类向量 | 数值数组 | 分类序列的单元格数组 | 数字序列的单元格数组
响应,指定为标签的分类向量,数字数组,分类序列的单元格数组或数字序列的单元格数组。的格式Y
取决于任务的类型。响应中不得包含NaN
。
分类
任务 | 格式 |
---|---|
图片分类 | 标签的N ×1分类向量,其中N是观察数。 |
序列到标签分类 | |
序列到序列分类 | 标签分类序列的N ×1单元格数组,其中 N是观察数。在将 |
对于一个观察到的序列到序列分类问题, sequences
也可以是向量。在这种情况下, Y
必须是标签的分类序列。
回归
任务 | 格式 |
---|---|
二维图像回归 |
|
3-D图像回归 |
|
序列一回归 | N × R矩阵,其中N是序列数,R是响应数。 |
序列到序列回归 | 数字序列的N ×1单元格数组,其中N 是序列数。序列是具有R行的矩阵 ,其中R是响应数。在将 |
对于只有一个观察值的逐序列回归问题, sequences
可以将其作为矩阵。在这种情况下, Y
必须是响应矩阵。
标准化响应通常有助于稳定和加速训练神经网络以进行回归。有关更多信息,请参阅 训练卷积神经网络进行回归。
tbl
— 输入数据
table
输入数据,指定为包含第一列中的预测变量和其余列中的响应的表。表格中的每一行都对应一个观察值。
表列中预测变量和响应的排列方式取决于问题的类型。
分类
任务 | 预测变量 | 回应 |
---|---|---|
图片分类 |
| 分类标签 |
序列到标签分类 | 包含序列或时间序列数据的MAT文件的绝对或相对文件路径。 MAT文件必须包含一个由矩阵表示的时间序列,该矩阵具有与数据点相对应的行和与时间步长相对应的列。 | 分类标签 |
序列到序列分类 | MAT文件的绝对或相对文件路径。MAT文件必须包含一个由分类向量表示的时间序列,并且每个时间步的标签均对应于其条目。 |
对于分类问题,如果您未指定 responseName
,则该函数默认使用的第二列中的响应tbl
。
回归
任务 | 预测变量 | 回应 |
---|---|---|
图像回归 |
|
|
序列一回归 | 包含序列或时间序列数据的MAT文件的绝对或相对文件路径。 MAT文件必须包含一个由矩阵表示的时间序列,该矩阵具有与数据点相对应的行和与时间步长相对应的列。 |
|
序列到序列回归 | MAT文件的绝对或相对文件路径。MAT文件必须包含一个由矩阵表示的时间序列,其中行对应于响应,列对应于时间步长。 |
对于回归问题,如果不指定 responseName
,则该函数默认使用的其余列tbl
。标准化响应通常有助于稳定和加速训练神经网络以进行回归。有关更多信息,请参阅训练卷积神经网络进行回归。
响应中不能包含NaN
。如果预测变量数据包含NaN
,则它们将通过训练传播。但是,在大多数情况下,培训无法收敛。
资料类型: table
layers
— 网络层
Layer
阵列 | LayerGraph
目的
网络层,指定为Layer
数组或LayerGraph
对象。
要创建依次连接所有层的网络,可以使用Layer
数组作为输入参数。在这种情况下,返回的网络是一个SeriesNetwork
对象。
有向无环图(DAG)网络具有复杂的结构,其中各层可以具有多个输入和输出。要创建DAG网络,请将网络体系结构指定为LayerGraph
对象,然后将该层图用作的输入参数 trainNetwork
。
options
— 培训选项
TrainingOptionsSGDM
| TrainingOptionsRMSProp
|TrainingOptionsADAM
培训选项,指定为TrainingOptionsSGDM
, TrainingOptionsRMSProp
或者 TrainingOptionsADAM
对象通过返回的trainingOptions
功能。要指定求解器和其他用于网络训练的选项,请使用 trainingOptions
。
输出参数
net
—训练有素的网络
SeriesNetwork
对象| DAGNetwork
目的
经过训练的网络,作为SeriesNetwork
对象或DAGNetwork
对象返回。
如果使用Layer
数组作为 layers
输入参数来训练网络,则它 net
是一个SeriesNetwork
对象。如果使用LayerGraph
对象作为输入参数来训练网络,则 net
该DAGNetwork
对象为对象。
info
—培训信息
结构
训练信息,以结构形式返回,其中每个字段是标量或数字向量,每个训练迭代具有一个元素。
对于分类问题,info
包含以下字段:
-
TrainingLoss
—损失函数值 -
TrainingAccuracy
-训练精度 -
ValidationLoss
—损失函数值 -
ValidationAccuracy
—验证准确性 -
BaseLearnRate
—学习率 -
FinalValidationLoss
—最终验证损失 -
FinalValidationAccuracy
—最终验证准确性
对于回归问题,info
包含以下字段:
-
TrainingLoss
—损失函数值 -
TrainingRMSE
—训练RMSE值 -
ValidationLoss
—损失函数值 -
ValidationRMSE
—验证RMSE值 -
BaseLearnRate
—学习率 -
FinalValidationLoss
—最终验证损失 -
FinalValidationRMSE
—最终验证RMSE
结构只包含的字段ValidationLoss
, ValidationAccuracy
,ValidationRMSE
,FinalValidationLoss
, FinalValidationAccuracy
和 FinalValidationRMSE
在options
指定的验证数据。所述'ValidationFrequency'
的选择trainingOptions
确定哪些迭代软件将计算验证指标。对于软件未计算验证指标的迭代,结构中的对应值为NaN
。
如果您的网络包含批处理规范化层,则最终验证指标通常与培训期间评估的验证指标不同。这是因为最终网络中的批处理归一化层执行的操作与训练期间不同。
更多关于
保存检查点网络并继续培训
深度学习工具箱™使您可以在训练期间的每个时期之后将网络另存为.mat文件。当您拥有大型网络或大型数据集并且训练需要很长时间时,这种定期保存特别有用。如果培训由于某种原因而中断,则可以从上次保存的检查点网络恢复培训。如果要 trainNetwork
保存检查点网络,则必须使用的'CheckpointPath'
名称/值对参数指定路径的名称trainingOptions
。如果指定的路径不存在,则 trainingOptions
返回错误。
trainNetwork
自动为检查点网络文件分配唯一的名称。在示例名称中 net_checkpoint__351__2018_04_12__18_09_52.mat
,351是迭代编号,2018_04_12
日期和保存网络18_09_52
的时间trainNetwork
。您可以通过双击或在命令行中使用load命令来加载检查点网络文件。例如:
<span style="color:#404040"><span style="color:inherit">加载net_checkpoint__351__2018_04_12__18_09_52.mat</span></span>
trainNetwork
。例如:
<span style="color:#404040"><span style="color:inherit">trainNetwork(XTrain,YTrain,net.Layers,options)</span></span>
浮点运算
深度学习工具箱中用于深度学习训练,预测和验证的所有功能都使用单精度浮点算术执行计算。深学习功能包括trainNetwork
,predict
, classify
,和 activations
。当您同时使用CPU和GPU训练网络时,该软件使用单精度算术。
参考资料
[1] Kudo,M.,J。Toyama和M.Shimbo。“使用通过区域的多维曲线分类”。 模式识别字母。卷 20,第11-13号,第1103-1111页。
[2] Kudo,M.,J。Toyama和M.Shimbo。日本元音数据集。https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
扩展功能
自动并行支持通过使用Parallel Computing Toolbox™自动并行
运行计算来加速代码。
更多推荐
所有评论(0)