【人工智能】【深度学习】 ② GAN核心算法介绍:生成器与判别器的博弈艺术
本文为【人工智能】【深度学习】系列第②篇,深入浅出讲解生成对抗网络(GAN)的核心原理。从“造假者 vs 警察”的大白话比喻出发,逐步推导极小极大博弈公式,详解生成器与判别器的对抗机制,并对比 CNN、RNN 的本质差异。附带完整 PyTorch 代码实现(逐行注释),涵盖 DCGAN 架构、训练流程与工程技巧,适合零基础读者入门生成式 AI。结尾预告下一期 Transformer 主题,延续系统
📖目录
前言
一句话导读:如果 AI 能“以假乱真”地生成人脸、绘画甚至视频,背后大概率是 GAN 在工作。本文从零讲起,用大白话拆解 GAN 的对抗思想,手把手推导核心公式,并附带 PyTorch 完整代码实现,带你掌握这一“造假大师”的底层逻辑。
1. GAN 是什么?——一场永不停歇的猫鼠游戏
想象两个角色:
- 生成器(Generator, G):一个伪造货币的罪犯,目标是印出连专家都看不出的假钞。
- 判别器(Discriminator, D):一个经验丰富的警察,目标是分辨手中钞票是真是假。
GAN 的训练过程,就是这两人不断博弈:
- G 不断改进造假技术;
- D 不断提升鉴伪能力;
- 最终,G 造出的假钞逼真到 D 无法区分——此时 G 就学会了真实数据的分布!
✅ 大白话总结:GAN 不是直接学“怎么画人脸”,而是通过“骗过一个越来越聪明的评委”,间接学会生成逼真样本。
2. 核心思想:GAN 为什么是“对抗”?公式是怎么来的?
很多人看到 GAN 的目标函数就头疼:
min G max D V ( D , G ) = E x ∼ p data ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
别急!我们不用一上来就啃公式,而是从人类直觉出发,一步步推导它为什么长这样。
2.1 大白话推导:从“造假 vs 打假”说起
假设你是一个假钞制造者(生成器 G),你的目标是什么?
→ 让警察(判别器 D)无法分辨你造的是假钞。
而警察的目标呢?
→ 一眼看出谁是真钞、谁是假钞。
于是,两人开始博弈:
- 如果警察太菜(D 弱),你随便印一张都能蒙混过关 → 你赢了,但学不到真本事;
- 如果警察太强(D 强),你印的全被识破 → 你压力山大,必须拼命改进技术;
- 最理想状态:你造的假钞连顶级专家都看不出破绽 → 此时,你就真正掌握了“真钞”的全部特征!
✅ 关键洞察:GAN 不是直接学“真钞长什么样”,而是通过“骗过一个越来越聪明的评委”,间接学会真实数据的分布。
2.2 公式怎么来的?一步步拆解
现在,我们把上面的故事翻译成数学语言。
2.2.1 第一步:定义判别器 D 的任务
D 是一个分类器,输入一张图 x x x,输出它是“真实的”概率 D ( x ) ∈ [ 0 , 1 ] D(x) \in [0,1] D(x)∈[0,1]。
- 对真实图像 x ∼ p data x \sim p_{\text{data}} x∼pdata,我们希望 D ( x ) → 1 D(x) \to 1 D(x)→1
- 对伪造图像 G ( z ) G(z) G(z),我们希望 D ( G ( z ) ) → 0 D(G(z)) \to 0 D(G(z))→0
所以,D 的损失应该让:
- log D ( x ) \log D(x) logD(x) 越大越好(因为 D ( x ) → 1 D(x) \to 1 D(x)→1 时 log D ( x ) → 0 \log D(x) \to 0 logD(x)→0,但负得少)
- log ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1−D(G(z))) 越大越好(因为 D ( G ( z ) ) → 0 D(G(z)) \to 0 D(G(z))→0 时该项也 → 0 \to 0 →0)
于是,D 的目标是 最大化 这两个期望之和:
max D { E x [ log D ( x ) ] + E z [ log ( 1 − D ( G ( z ) ) ) ] } \max_D \left\{ \mathbb{E}_{x}[\log D(x)] + \mathbb{E}_{z}[\log(1 - D(G(z)))] \right\} Dmax{Ex[logD(x)]+Ez[log(1−D(G(z)))]}
2.2.2 第二步:定义生成器 G 的任务
G 的目标正好相反:它希望 D 把假图当成真图,即 D ( G ( z ) ) → 1 D(G(z)) \to 1 D(G(z))→1。
这意味着 log ( 1 − D ( G ( z ) ) ) → log ( 0 ) → − ∞ \log(1 - D(G(z))) \to \log(0) \to -\infty log(1−D(G(z)))→log(0)→−∞ —— 这对 G 来说是个巨大惩罚!
所以 G 要最小化这个惩罚项,等价于:
min G E z [ log ( 1 − D ( G ( z ) ) ) ] \min_G \mathbb{E}_{z}[\log(1 - D(G(z)))] GminEz[log(1−D(G(z)))]
2.2.3 第三步:合二为一 ,极小极大博弈
把两者放在一起,就得到 Goodfellow 在 2014 年提出的经典目标函数:
min G max D V ( D , G ) = E x ∼ p data [ log D ( x ) ] ⏟ D 想让真图得分高 + E z ∼ p z [ log ( 1 − D ( G ( z ) ) ) ] ⏟ D 想让假图得分低,G 想让它得分高 \boxed{ \min_G \max_D V(D, G) = \underbrace{\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)]}_{\text{D 想让真图得分高}} + \underbrace{\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]}_{\text{D 想让假图得分低,G 想让它得分高}} } GminDmaxV(D,G)=D 想让真图得分高 Ex∼pdata[logD(x)]+D 想让假图得分低,G 想让它得分高 Ez∼pz[log(1−D(G(z)))]
💡 训练技巧补充:
实践中发现,早期 G 很弱, D ( G ( z ) ) ≈ 0 D(G(z)) \approx 0 D(G(z))≈0,导致 log ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1−D(G(z))) 的梯度几乎为零(梯度消失)。
因此,训练 G 时常用替代损失:最大化 log D ( G ( z ) ) \log D(G(z)) logD(G(z))(即最小化 − log D ( G ( z ) ) -\log D(G(z)) −logD(G(z))),让梯度更稳定。
2.3 GAN 能用在哪儿?典型应用场景
| 领域 | 应用案例 | 说明 |
|---|---|---|
| 图像生成 | StyleGAN 生成人脸 | 可生成超写实、无版权的人脸图像 |
| 图像翻译 | CycleGAN:马 ↔ 斑马 | 无需配对数据,实现风格/域转换 |
| 数据增强 | 医疗影像合成 | 为罕见病生成更多训练样本 |
| 超分辨率 | SRGAN | 将模糊图像重建为高清细节 |
| 艺术创作 | AI 绘画、音乐生成 | 辅助创意工作者 |
✅ 核心优势:GAN 是无监督生成模型,不需要标签就能学习数据分布,特别适合“创造新内容”。
2.4 对比 CNN 与 RNN:三大网络的本质差异
| 特性 | CNN(卷积神经网络) | RNN(循环神经网络) | GAN(生成对抗网络) |
|---|---|---|---|
| 核心目标 | 识别/分类(判别式) | 序列建模/预测(判别式) | 生成新样本(生成式) |
| 输入输出 | 图像 → 类别/检测框 | 序列 → 序列/类别 | 噪声 → 逼真数据(如图像) |
| 是否需要标签 | 是(监督学习) | 通常是(如语言模型) | 否(无监督生成) |
| 典型结构 | 卷积+池化+全连接 | 循环单元(如 LSTM) | 生成器 + 判别器(对抗) |
| 代表应用 | 人脸识别、目标检测 | 机器翻译、语音识别 | 人脸生成、图像修复、风格迁移 |
| 能否“创造”? | ❌ 只能判断已有内容 | ❌ 只能预测下一个词 | ✅ 能凭空生成全新内容 |
📌 一句话总结:
- CNN 看图识物,
- RNN 记忆上下文,
- GAN 无中生有。
3. 网络架构:谁是谁?
| 模块 | 输入 | 输出 | 典型结构 |
|---|---|---|---|
| 生成器 G | 随机噪声 z ∈ R 100 z \in \mathbb{R}^{100} z∈R100 | 假图像 G ( z ) ∈ R 3 × 64 × 64 G(z) \in \mathbb{R}^{3 \times 64 \times 64} G(z)∈R3×64×64 | 全连接 → 转置卷积(ConvTranspose)堆叠 |
| 判别器 D | 图像 x x x(真 or 假) | 概率标量 D ( x ) ∈ [ 0 , 1 ] D(x) \in [0,1] D(x)∈[0,1] | 卷积层堆叠 → 全连接 → Sigmoid |
🖼️ 架构图示意:
z → [G] → fake_img → [D] → probreal_img → [D] → prob
4. 完整代码示例:PyTorch 实现 DCGAN(深度卷积 GAN)
# 导入必要库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# ----------------------------
# 1. 定义生成器 Generator 网络
# 功能:将随机噪声向量 z 映射为逼真的假图像
# 输入维度: (batch_size, 100)
# 输出维度: (batch_size, 3, 64, 64) —— RGB 图像
# ----------------------------
class Generator(nn.Module):
def __init__(self):
# 调用父类初始化方法
super(Generator, self).__init__()
# 第一阶段:全连接层 + 批归一化 + ReLU
# 将 100 维噪声扩展为 1024*4*4 = 16384 维,为后续上采样做准备
self.fc = nn.Sequential(
nn.Linear(100, 1024 * 4 * 4), # 线性变换:100 → 16384
nn.BatchNorm1d(1024 * 4 * 4), # 批归一化,加速训练并提升稳定性
nn.ReLU(True) # 激活函数,引入非线性
)
# 第二阶段:转置卷积(反卷积)堆叠,逐步上采样至 64x64
self.deconv = nn.Sequential(
# 层1: (1024, 4, 4) → (512, 8, 8)
# kernel_size=4, stride=2, padding=1 实现 2 倍上采样
nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512), # 2D 批归一化
nn.ReLU(True),
# 层2: (512, 8, 8) → (256, 16, 16)
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(True),
# 层3: (256, 16, 16) → (128, 32, 32)
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(True),
# 层4: (128, 32, 32) → (3, 64, 64),输出 3 通道 RGB 图像
nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
nn.Tanh() # Tanh 将输出压缩到 [-1, 1],与归一化真实图像匹配
)
def forward(self, z):
# 前向传播:输入噪声 z
x = self.fc(z) # 全连接层处理
x = x.view(-1, 1024, 4, 4) # 重塑为 4D 张量 (batch, channels, height, width)
return self.deconv(x) # 通过转置卷积生成图像
# ----------------------------
# 2. 定义判别器 Discriminator 网络
# 功能:判断输入图像是真实的还是生成器伪造的
# 输入维度: (batch_size, 3, 64, 64)
# 输出维度: (batch_size, 1) —— 表示“真实”的概率
# ----------------------------
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# 卷积层堆叠,逐步下采样至 4x4 特征图
self.conv = nn.Sequential(
# 层1: (3, 64, 64) → (64, 32, 32)
# 使用 LeakyReLU 避免梯度消失(尤其在判别器中)
nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
# 层2: (64, 32, 32) → (128, 16, 16)
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128), # 批归一化提升稳定性
nn.LeakyReLU(0.2, inplace=True),
# 层3: (128, 16, 16) → (256, 8, 8)
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
# 层4: (256, 8, 8) → (512, 4, 4)
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
)
# 最终分类头:展平 + 全连接 + Sigmoid
self.fc = nn.Sequential(
nn.Flatten(), # 将 (512, 4, 4) 展平为 (512*16,)
nn.Linear(512 * 4 * 4, 1), # 映射到单个输出
nn.Sigmoid() # 输出 [0,1] 概率
)
def forward(self, x):
# 前向传播:输入图像 x
x = self.conv(x) # 卷积特征提取
return self.fc(x) # 分类输出
# ----------------------------
# 3. 训练配置与流程(核心逻辑)
# ----------------------------
# 自动选择 GPU 或 CPU 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 实例化生成器和判别器,并移至设备
G = Generator().to(device)
D = Discriminator().to(device)
# 定义损失函数:二元交叉熵(Binary Cross Entropy)
criterion = nn.BCELoss()
# 优化器:使用 Adam,学习率 0.0002,动量参数 (0.5, 0.999) —— DCGAN 推荐设置
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
# 标签定义:真实样本标签为 1.0,伪造样本标签为 0.0
real_label = 1.0
fake_label = 0.0
# 假设 dataloader 已定义(此处省略数据加载代码)
# num_epochs 为训练轮数
# 开始训练循环
for epoch in range(num_epochs):
for i, (real_imgs, _) in enumerate(dataloader):
batch_size = real_imgs.size(0) # 获取当前批次大小
real_imgs = real_imgs.to(device) # 将真实图像移至 GPU/CPU
# ---------------------
# 步骤1: 训练判别器 D
# 目标:最大化 D 对真实/伪造样本的判别能力
# ---------------------
optimizer_D.zero_grad() # 清空判别器梯度
# 用真实图像训练 D:希望 D(real_imgs) → 1
label = torch.full((batch_size,), real_label, device=device) # 创建全 1 标签
output = D(real_imgs).view(-1) # 判别器输出,展平为 1D
loss_D_real = criterion(output, label) # 计算真实样本损失
# 生成假图像
noise = torch.randn(batch_size, 100, device=device) # 从标准正态分布采样噪声
fake_imgs = G(noise) # 生成器生成假图像
# 用假图像训练 D:希望 D(fake_imgs) → 0
label.fill_(fake_label) # 将标签填充为 0
# 注意:使用 .detach() 阻断梯度回传到生成器 G,仅更新 D
output = D(fake_imgs.detach()).view(-1)
loss_D_fake = criterion(output, label) # 计算伪造样本损失
# 判别器总损失 = 真实损失 + 伪造损失
loss_D = loss_D_real + loss_D_fake
loss_D.backward() # 反向传播计算梯度
optimizer_D.step() # 更新判别器参数
# ---------------------
# 步骤2: 训练生成器 G
# 目标:最小化 log(1 - D(G(z))),等价于最大化 log(D(G(z)))
# 即:让 D 认为假图是真实的
# ---------------------
optimizer_G.zero_grad() # 清空生成器梯度
# 重新设置标签为 1(欺骗 D)
label.fill_(real_label)
# 注意:此处不使用 detach(),梯度需回传至 G
output = D(fake_imgs).view(-1) # D 对假图的判别结果
loss_G = criterion(output, label) # G 的损失:希望 D 输出接近 1
loss_G.backward() # 反向传播(梯度流经 D 和 G)
optimizer_G.step() # 更新生成器参数
✅ 代码亮点:
- 严格遵循 DCGAN 论文推荐架构(转置卷积、BatchNorm、LeakyReLU/Tanh)
- 交替训练策略:每步先更新 D,再更新 G,防止一方过强导致训练崩溃
- 梯度控制:在训练 D 时对 fake_imgs 使用 .detach(),避免不必要的梯度计算
- 标签平滑(可选改进):可将 real_label 设为 0.9 而非 1.0,提升鲁棒性
5. GAN 的挑战与改进
| 问题 | 表现 | 解决方案 |
|---|---|---|
| 模式崩溃(Mode Collapse) | G 只生成少数几种样本 | WGAN(Wasserstein GAN)、Unrolled GAN |
| 训练不稳定 | D 过强导致 G 梯度消失 | 特征匹配、谱归一化(Spectral Norm) |
| 评估困难 | 无明确损失衡量生成质量 | Inception Score (IS)、FID 分数 |
📌 WGAN 核心改进:将原始 JS 散度替换为 Wasserstein 距离,损失函数变为:
min G max ∥ f ∥ L ≤ 1 E x ∼ p data [ f ( x ) ] − E z ∼ p z [ f ( G ( z ) ) ] \min_G \max_{\|f\|_L \leq 1} \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{z \sim p_z}[f(G(z))] Gmin∥f∥L≤1maxEx∼pdata[f(x)]−Ez∼pz[f(G(z))]
其中 f f f 为 1-Lipschitz 函数,通过权重裁剪或梯度惩罚实现。
6. 应用场景:GAN 能做什么?
- 图像生成:StyleGAN 生成超写实人脸
- 图像翻译:CycleGAN 实现马↔斑马、照片↔油画
- 数据增强:为医疗影像生成稀缺病灶样本
- 超分辨率:SRGAN 将低清图变高清
7. 结语与预告
GAN 开启了“无监督生成”的新纪元,其对抗思想也深刻影响了强化学习、自监督学习等领域。理解 GAN,是掌握现代生成式 AI 的关键一步。
🔗 延伸阅读:
- 上一篇:【[人工智能][深度学习] ① RNN核心算法介绍:从循环结构到LSTM门控机制](https://blog.csdn.net/xiezhiyi007/article/details/155281772)**
- 再上一篇:【[人工智能]发展历程全景解析:从图灵测试到大模型时代(含CNN、Q-Learning深度实践)](https://blog.csdn.net/xiezhiyi007/article/details/155131043)
💡 下一篇预告:
【人工智能】【深度学习】 ③ Transformer核心算法介绍:自注意力机制如何颠覆AI?
将深入剖析 QKV 计算、多头注意力、位置编码,并对比 RNN/CNN 的优劣,敬请期待!
更多推荐



所有评论(0)