这是一篇使用增强学习来进行模型搜索的论文。
结构如下图:

 

overview

由于不知道网络的长度和结构,作者使用了一个RNN作为控制器,使用该控制器来产生一串信息,用于构建网络。之后训练该网络,并用网络的accuracy作为reward返回给控制器来更新控制器的参数,达到更优的策略。
其中控制器(RNN)的设计借鉴了sequence to sequence的思想,不同的是它优化的是一个不可微的目标,也就是 网络的accuracy。

方法

CNN

上图展示了如何使用RNN控制器产生一个简单的CNN网络,对于CNN网络的每一层,控制器都会产生一组超参数,当层数达到一个阈值,就会停止。RNN的参数\theta_{t}会通过增强学习算法更新,以得到更好的模型结构。

使用REINFORCE来训练

控制器可以看作agent,控制器产生一组token,也就是超参数,看作agent的action,使用产生的模型在验证集的准确率作为reward。因此,控制器需要优化下面公式:

optimization target


但是

R

是不可微分的,因此不能使用传统的BP算法,在论文中,作者使用了REINFORCE。该算法是增强学习的常用算法之一,算法将agent的policy看作一个函数,通过reward来进行参数的更新,从而实现reward的最优化。并且该算法给出了rewardpolicy参数的导数公式。

REINFORCE

 

它的一个经验近似公式如下:

 

empirical approximation

m是一个batch中的模型个数
T是超参数的个数

由于以上公式会遇到variance过大的问题,可以使用如下带baseline的公式

 

with baseline

分布式训练加速

分布式框架如下图

PS


思路:其中 parameter server共同保存了控制器的所有参数,这些server将参数分发给controller,每一个controller使用得到的参数进行模型的构建,这里由于得到的参数可能不同,构建模型的策略是随机的,导致每次构建的网络结构也会不同。每个controller会构建一个batch,也就是

m

个网络,然后并行地训练这些网络,得到它们的accuracy。也就是说,每一个controller会得到一个batch也就是

m

个网络,和它们的accuracy,然后根据之前提到的公式,计算参数的梯度。接着,计算完梯度的controller会将梯度发送给servers。这些server在得到梯度后,分别对自己负责的参数进行更新。更新后,当controller再次训练时,会得到更新后的参数。这里如果每个controller各自发送自己的梯度,之间不进行同步,就是异步更新。

 

skip connection and other layer types

为了能够在搜索空间中加入类似于resnet,inception的skip connection。作者设计了anchor,用于表示是否和前面几个层进行连接。如下

 

anchor

相对应的,agent的action选择如下图:

 

agent action

它会根据前面算出的概率P和自己的策略,判断是否加入connection。
最终,所有没有后续连接的层都会被连接到输出层,如果连接的两个层大小不一致,就将小的层用0来填充。(pad with zeros)

产生RNN网络cell

为了产生RNN网络cell,类似于LSTM,作者使用了一种树的结构,每一个树的节点都会拥有一个操作(addition, elementwise multiplication, etc.)和一个激活函数(tanh, sigmoid等)。每一个节点的输入,都连接了两个其他节点的输出。为了使用上面描述的方法,作者将每个节点编号,按照顺序预测。如下图:

 

RNN

 

根据预测的结果,将会按照如下方式构建网络:

 

Computation steps

总结

这篇文章将增强学习的算法应用在了模型预测上,并且巧妙地使用RNN来预测参数。总体思路依旧是通过在一个有限的搜索空间进行高效的搜索,来不断提高agent预测的模型的准确率。
note:REINFORCE算法真神奇,能够直接使用一个简单的标量reward来知道agent更新参数的方式。



作者:Junr_0926
链接:https://www.jianshu.com/p/b4fd2d4b96d9
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐