GAN损失函数

  • 对抗网络中生成器的目的是尽可能使生成样本分布拟合真实样本分布。
  • 鉴别器目的是尽可能鉴别输入样本来自于真实的还是生成的。
  • 大家都知道GAN的优化目标函数如下:
    在这里插入图片描述
  • 但其参数到底是如何优化的呢?答案是交替迭代优化;如下图所示:
    在这里插入图片描述
    • 图(a):固定G参数不变,优化D的参数,即最大化 m a x V ( D , G ) maxV(D,G) maxV(D,G)等价于 m i n [ − V ( D , G ) ] min[-V(D,G)] min[V(D,G)]。因此,D的损失函数等价如下:
      在这里插入图片描述
    • 鉴别器认为来自真实数据样本的标签为1而来自生成样本的标签为0。因此,其优化过程是类似Sigmoid 的二分类,即sigmoid的交叉熵。
    • Tensorflow中的交叉熵是用tf.nn.sigmoid_entropy_with_logits(logits,labels)表示。
    • 查看TF的sigmoid交叉熵API可帮助理解:
      • x = logits表示鉴别器输出特征, z = labels表示对应的标签;则交叉熵表示为 z ∗ − l o g ( s i g m o i d ( x ) ) + ( 1 − z ) ∗ − l o g ( 1 − s i g m o i d ( x ) ) z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) zlog(sigmoid(x))+(1z)log(1sigmoid(x)).
      • 推导如下:
        z ∗ − l o g ( s i g m o i d ( x ) ) + ( 1 − z ) ∗ − l o g ( 1 − s i g m o i d ( x ) ) z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) zlog(sigmoid(x))+(1z)log(1sigmoid(x))
        = z ∗ − l o g ( 1 / ( 1 + e − x ) ) + ( 1 − z ) ∗ − l o g ( e − x / ( 1 + e − x ) ) = z * -log(1 / (1 +e^{-x})) + (1 - z) * -log(e^{-x} / (1 +e^{-x})) =zlog(1/(1+ex))+(1z)log(ex/(1+ex))
        = z ∗ l o g ( 1 + e − x ) + ( 1 − z ) ∗ ( − l o g ( e − x ) + l o g ( 1 + e − x ) ) = z * log(1 + e^{-x}) + (1 - z) * (-log(e^{-x}) + log(1 + e^{-x})) =zlog(1+ex)+(1z)(log(ex)+log(1+ex))
        = z ∗ l o g ( 1 + e − x ) + ( 1 − z ) ∗ ( x + l o g ( 1 + e − x ) ) = z * log(1 + e^{-x}) + (1 - z) * (x + log(1 + e^{-x})) =zlog(1+ex)+(1z)(x+log(1+ex))
        = ( 1 − z ) ∗ x + l o g ( 1 + e − x ) = (1 - z) * x + log(1 + e^{-x}) =(1z)x+log(1+ex)
        = x − x ∗ z + l o g ( 1 + e − x ) = x - x * z + log(1 + e^{-x}) =xxz+log(1+ex)
        x<0,可进一步化简为:
        = l o g ( e x ) − x ∗ z + l o g ( 1 + e − x ) = log(e^x) - x * z + log(1 + e^{-x}) =log(ex)xz+log(1+ex)
        = − x ∗ z + l o g ( 1 + e x ) = - x * z + log(1 + e^{x}) =xz+log(1+ex)
      • The logistic loss formula from above is x - x * z + log(1 + exp(-x))
      • For x < 0, a more numerically stable formula is -x * z + log(1 + exp(x))
      • Note that these two expressions can be combined into the following:max(x, 0) - x * z + log(1 + exp(-abs(x)))
    • z=1时,真实样本对应的损失为: − l o g ( s i g m o i d ( x ) ) = l o g ( e − x + 1 ) = l o g ( e x + 1 ) − x -log(sigmoid(x))=log(e^{-x}+1)=log(e^x+1)-x log(sigmoid(x))=log(ex+1)=log(ex+1)x.
    • z=0时,生成样本对应的损失为: − l o g ( 1 − s i g m o i d ( x ) ) = x + l o g ( e − x + 1 ) = l o g ( e x + 1 ) -log(1-sigmoid(x))=x+log(e^{-x}+1)=log(e^x+1) log(1sigmoid(x))=x+log(ex+1)=log(ex+1).其中 s o f t p l u s ( x ) = l o g e ( 1 + e x ) softplus(x)=log_e(1+e^x) softplus(x)=loge(1+ex).
      在这里插入图片描述
    • 由于JS散度具有非负性,当两者分布相等时,其散度为零。因此,D(x)训练得越好,G(z)就越接近最优,则生成器的损失越接近于生成样本分布和真实样本分布的JS 散度。
  • GAN网络算法流程如下表:
    在这里插入图片描述
  • 实际上,式(2-6)可能并没有提供足够的梯度来更新G 的参数。训练初期, 由于G 没有得到较好的训练,生成样本很差,D 会以高置信度的概率来拒绝初期生成的样本,导致log(1−D(G(z)))达到饱和,无法提供足够的梯度来更新 G。于是,采用最大化log(D(G(z)))来代替最小化log(1−D(G(z)))更新 G的参数。
  • tensorflow框架下的GAN的损失代码如下:
# the first term of discriminator loss of real sample:-log[D(x)]
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_entropy_with_logits(logits=D_real_logits,labels=tf.ones_like(D));
# the second term of discriminator loss of fake sample:-log[1-D(G(z))]
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_entropy_with_logits(logits=D_fake_logits,labels=tf.zeros_like(D));
# D_fake_logits是鉴别器对生成器生成样本提取的特征 D(G(z))
d_loss = d_loss_real + d_loss_fake ;
# -log[D(G(z))]
g_loss = tf.reduce_mean(tf.nn.sigmoid_entropy_with_logits(logits=D_fake_logits,labels=tf.ones_like(D));
  • D表示对应维度大小为batchsize的标签
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐