生成式对抗网络即GAN由生成器和判别器组成。原论文中,关于生成器和判别器的损失函数是写成以下形式:
在这里插入图片描述
首先,第一个式子我们不看梯度符号的话即为判别器的损失函数,logD(xi)为判别器将真实数据判定为真实数据的概率,log(1-D(G(zi)))为判别器将生成器生成的虚假数据判定为真实数据的对立面即将虚假数据仍判定为虚假数据的概率。判别器就相当于警察,在鉴别真伪时,必须要保证鉴别的结果真的就是真的假的就是假的,所以判别器的总损失即为二者之和,应当最大化该损失。由于判别器(警察)鉴别真伪的能力随着训练次数的增加越来越高,生成器就要与之“对抗”,生成器就要相应地提高“造假”技术,来迷惑判别器。第二个式子为第一个式子的第二项,含义相同,只不过对于生成器应当最小化该项,生成器当然希望辨别器将虚假数据仍判定为虚假数据的概率越低越好,即将虚假数据误判定为真实数据的概率越大越好,即最大化log(D(G(zi)))损失函数。所以二者相互提高或者减小自身的损失,以不断互相对抗。
GitHub上的Deep Convolutional Generative Adversarial Networks(DCGAN)的损失函数是用nn.BCELoss()来写的,具体如下:

import torch
from torch.autograd import Variable

batch_size = 10

adversarial_loss = nn.BCELoss()

valid = Variable(torch.Tensor(batch_size, 1).fill_(1.0), requires_grad=False)
fake = Variable(torch.Tensor(batch_size, 1).fill_(0.0), requires_grad=False)

# dis:鉴别器 
# gen_imgs:生成器生成图像
# real_imgs:真实图像
g_loss = adversarial_loss(dis(gen_imgs), valid)

real_loss = adversarial_loss(dis(real_imgs), valid)
fake_loss = adversarial_loss(dis(gen_imgs), fake)
d_loss = real_loss + fake_loss

nn.BCELoss的计算公式这里不再赘述,可以查看官方文档,我手写一下代码中的g_lossd_loss:
在这里插入图片描述
在这里插入图片描述
torch中都是最小化损失函数,所以d_loss能理解,而g_loss只不过对原论文中的写法换了一种表述,即最大化D(G(Z))的概率:使得鉴别器将生成器生成的图像鉴别为真的概率越大越好。

在这里插入图片描述
我用pytorch搭建了一个简易的GAN,没用卷积层,只是单纯的全连接层,利用mnist图像作为真实数据,随机生成100维的随机噪声作为生成器的输入,20次迭代的最终结果如上图,可以看出GAN多多少少能有些真实图像的大概轮廓。

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐