前言

一句话导读:如果 AI 能“以假乱真”地生成人脸、绘画甚至视频,背后大概率是 GAN 在工作。本文从零讲起,用大白话拆解 GAN 的对抗思想,手把手推导核心公式,并附带 PyTorch 完整代码实现,带你掌握这一“造假大师”的底层逻辑。


1. GAN 是什么?——一场永不停歇的猫鼠游戏

想象两个角色:

  • 生成器(Generator, G):一个伪造货币的罪犯,目标是印出连专家都看不出的假钞。
  • 判别器(Discriminator, D):一个经验丰富的警察,目标是分辨手中钞票是真是假。

GAN 的训练过程,就是这两人不断博弈:

  1. G 不断改进造假技术;
  2. D 不断提升鉴伪能力;
  3. 最终,G 造出的假钞逼真到 D 无法区分——此时 G 就学会了真实数据的分布!

大白话总结:GAN 不是直接学“怎么画人脸”,而是通过“骗过一个越来越聪明的评委”,间接学会生成逼真样本。


2. 核心思想:GAN 为什么是“对抗”?公式是怎么来的?

很多人看到 GAN 的目标函数就头疼:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

别急!我们不用一上来就啃公式,而是从人类直觉出发,一步步推导它为什么长这样。

2.1 大白话推导:从“造假 vs 打假”说起

假设你是一个假钞制造者(生成器 G),你的目标是什么?
让警察(判别器 D)无法分辨你造的是假钞

而警察的目标呢?
一眼看出谁是真钞、谁是假钞

于是,两人开始博弈:

  • 如果警察太菜(D 弱),你随便印一张都能蒙混过关 → 你赢了,但学不到真本事;
  • 如果警察太强(D 强),你印的全被识破 → 你压力山大,必须拼命改进技术;
  • 最理想状态:你造的假钞连顶级专家都看不出破绽 → 此时,你就真正掌握了“真钞”的全部特征!

关键洞察:GAN 不是直接学“真钞长什么样”,而是通过“骗过一个越来越聪明的评委”,间接学会真实数据的分布。


2.2 公式怎么来的?一步步拆解

现在,我们把上面的故事翻译成数学语言。

2.2.1 第一步:定义判别器 D 的任务

D 是一个分类器,输入一张图 x x x,输出它是“真实的”概率 D ( x ) ∈ [ 0 , 1 ] D(x) \in [0,1] D(x)[0,1]

  • 真实图像 x ∼ p data x \sim p_{\text{data}} xpdata,我们希望 D ( x ) → 1 D(x) \to 1 D(x)1
  • 伪造图像 G ( z ) G(z) G(z),我们希望 D ( G ( z ) ) → 0 D(G(z)) \to 0 D(G(z))0

所以,D 的损失应该让:

  • log ⁡ D ( x ) \log D(x) logD(x) 越大越好(因为 D ( x ) → 1 D(x) \to 1 D(x)1 log ⁡ D ( x ) → 0 \log D(x) \to 0 logD(x)0,但负得少)
  • log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1D(G(z))) 越大越好(因为 D ( G ( z ) ) → 0 D(G(z)) \to 0 D(G(z))0 时该项也 → 0 \to 0 0

于是,D 的目标是 最大化 这两个期望之和:
max ⁡ D { E x [ log ⁡ D ( x ) ] + E z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] } \max_D \left\{ \mathbb{E}_{x}[\log D(x)] + \mathbb{E}_{z}[\log(1 - D(G(z)))] \right\} Dmax{Ex[logD(x)]+Ez[log(1D(G(z)))]}

2.2.2 第二步:定义生成器 G 的任务

G 的目标正好相反:它希望 D 把假图当成真图,即 D ( G ( z ) ) → 1 D(G(z)) \to 1 D(G(z))1

这意味着 log ⁡ ( 1 − D ( G ( z ) ) ) → log ⁡ ( 0 ) → − ∞ \log(1 - D(G(z))) \to \log(0) \to -\infty log(1D(G(z)))log(0) —— 这对 G 来说是个巨大惩罚

所以 G 要最小化这个惩罚项,等价于:
min ⁡ G E z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \mathbb{E}_{z}[\log(1 - D(G(z)))] GminEz[log(1D(G(z)))]

2.2.3 第三步:合二为一 ,极小极大博弈

把两者放在一起,就得到 Goodfellow 在 2014 年提出的经典目标函数:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data [ log ⁡ D ( x ) ] ⏟ D 想让真图得分高 + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ⏟ D 想让假图得分低,G 想让它得分高 \boxed{ \min_G \max_D V(D, G) = \underbrace{\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)]}_{\text{D 想让真图得分高}} + \underbrace{\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]}_{\text{D 想让假图得分低,G 想让它得分高}} } GminDmaxV(D,G)=想让真图得分高 Expdata[logD(x)]+想让假图得分低,想让它得分高 Ezpz[log(1D(G(z)))]

💡 训练技巧补充
实践中发现,早期 G 很弱, D ( G ( z ) ) ≈ 0 D(G(z)) \approx 0 D(G(z))0,导致 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1D(G(z))) 的梯度几乎为零(梯度消失)。
因此,训练 G 时常用替代损失:最大化 log ⁡ D ( G ( z ) ) \log D(G(z)) logD(G(z))(即最小化 − log ⁡ D ( G ( z ) ) -\log D(G(z)) logD(G(z))),让梯度更稳定。


2.3 GAN 能用在哪儿?典型应用场景

领域 应用案例 说明
图像生成 StyleGAN 生成人脸 可生成超写实、无版权的人脸图像
图像翻译 CycleGAN:马 ↔ 斑马 无需配对数据,实现风格/域转换
数据增强 医疗影像合成 为罕见病生成更多训练样本
超分辨率 SRGAN 将模糊图像重建为高清细节
艺术创作 AI 绘画、音乐生成 辅助创意工作者

核心优势:GAN 是无监督生成模型,不需要标签就能学习数据分布,特别适合“创造新内容”。


2.4 对比 CNN 与 RNN:三大网络的本质差异

特性 CNN(卷积神经网络) RNN(循环神经网络) GAN(生成对抗网络)
核心目标 识别/分类(判别式) 序列建模/预测(判别式) 生成新样本(生成式)
输入输出 图像 → 类别/检测框 序列 → 序列/类别 噪声 → 逼真数据(如图像)
是否需要标签 是(监督学习) 通常是(如语言模型) (无监督生成)
典型结构 卷积+池化+全连接 循环单元(如 LSTM) 生成器 + 判别器(对抗)
代表应用 人脸识别、目标检测 机器翻译、语音识别 人脸生成、图像修复、风格迁移
能否“创造”? ❌ 只能判断已有内容 ❌ 只能预测下一个词 ✅ 能凭空生成全新内容

📌 一句话总结

  • CNN 看图识物
  • RNN 记忆上下文
  • GAN 无中生有

3. 网络架构:谁是谁?

模块 输入 输出 典型结构
生成器 G 随机噪声 z ∈ R 100 z \in \mathbb{R}^{100} zR100 假图像 G ( z ) ∈ R 3 × 64 × 64 G(z) \in \mathbb{R}^{3 \times 64 \times 64} G(z)R3×64×64 全连接 → 转置卷积(ConvTranspose)堆叠
判别器 D 图像 x x x(真 or 假) 概率标量 D ( x ) ∈ [ 0 , 1 ] D(x) \in [0,1] D(x)[0,1] 卷积层堆叠 → 全连接 → Sigmoid

🖼️ 架构图示意
z → [G] → fake_img → [D] → prob
real_img → [D] → prob


4. 完整代码示例:PyTorch 实现 DCGAN(深度卷积 GAN)

# 导入必要库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# ----------------------------
# 1. 定义生成器 Generator 网络
# 功能:将随机噪声向量 z 映射为逼真的假图像
# 输入维度: (batch_size, 100)
# 输出维度: (batch_size, 3, 64, 64) —— RGB 图像
# ----------------------------

class Generator(nn.Module):
    def __init__(self):
        # 调用父类初始化方法
        super(Generator, self).__init__()
        
        # 第一阶段:全连接层 + 批归一化 + ReLU
        # 将 100 维噪声扩展为 1024*4*4 = 16384 维,为后续上采样做准备
        self.fc = nn.Sequential(
            nn.Linear(100, 1024 * 4 * 4),      # 线性变换:100 → 16384
            nn.BatchNorm1d(1024 * 4 * 4),       # 批归一化,加速训练并提升稳定性
            nn.ReLU(True)                       # 激活函数,引入非线性
        )
        
        # 第二阶段:转置卷积(反卷积)堆叠,逐步上采样至 64x64
        self.deconv = nn.Sequential(
            # 层1: (1024, 4, 4) → (512, 8, 8)
            # kernel_size=4, stride=2, padding=1 实现 2 倍上采样
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),                # 2D 批归一化
            nn.ReLU(True),
            
            # 层2: (512, 8, 8) → (256, 16, 16)
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            # 层3: (256, 16, 16) → (128, 32, 32)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # 层4: (128, 32, 32) → (3, 64, 64),输出 3 通道 RGB 图像
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()                           # Tanh 将输出压缩到 [-1, 1],与归一化真实图像匹配
        )

    def forward(self, z):
        # 前向传播:输入噪声 z
        x = self.fc(z)                          # 全连接层处理
        x = x.view(-1, 1024, 4, 4)              # 重塑为 4D 张量 (batch, channels, height, width)
        return self.deconv(x)                   # 通过转置卷积生成图像


# ----------------------------
# 2. 定义判别器 Discriminator 网络
# 功能:判断输入图像是真实的还是生成器伪造的
# 输入维度: (batch_size, 3, 64, 64)
# 输出维度: (batch_size, 1) —— 表示“真实”的概率
# ----------------------------

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # 卷积层堆叠,逐步下采样至 4x4 特征图
        self.conv = nn.Sequential(
            # 层1: (3, 64, 64) → (64, 32, 32)
            # 使用 LeakyReLU 避免梯度消失(尤其在判别器中)
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 层2: (64, 32, 32) → (128, 16, 16)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),                # 批归一化提升稳定性
            nn.LeakyReLU(0.2, inplace=True),
            
            # 层3: (128, 16, 16) → (256, 8, 8)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 层4: (256, 8, 8) → (512, 4, 4)
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # 最终分类头:展平 + 全连接 + Sigmoid
        self.fc = nn.Sequential(
            nn.Flatten(),                       # 将 (512, 4, 4) 展平为 (512*16,)
            nn.Linear(512 * 4 * 4, 1),          # 映射到单个输出
            nn.Sigmoid()                        # 输出 [0,1] 概率
        )

    def forward(self, x):
        # 前向传播:输入图像 x
        x = self.conv(x)                        # 卷积特征提取
        return self.fc(x)                       # 分类输出


# ----------------------------
# 3. 训练配置与流程(核心逻辑)
# ----------------------------

# 自动选择 GPU 或 CPU 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化生成器和判别器,并移至设备
G = Generator().to(device)
D = Discriminator().to(device)

# 定义损失函数:二元交叉熵(Binary Cross Entropy)
criterion = nn.BCELoss()

# 优化器:使用 Adam,学习率 0.0002,动量参数 (0.5, 0.999) —— DCGAN 推荐设置
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 标签定义:真实样本标签为 1.0,伪造样本标签为 0.0
real_label = 1.0
fake_label = 0.0

# 假设 dataloader 已定义(此处省略数据加载代码)
# num_epochs 为训练轮数

# 开始训练循环
for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        batch_size = real_imgs.size(0)          # 获取当前批次大小
        real_imgs = real_imgs.to(device)        # 将真实图像移至 GPU/CPU
        
        # ---------------------
        # 步骤1: 训练判别器 D
        # 目标:最大化 D 对真实/伪造样本的判别能力
        # ---------------------
        optimizer_D.zero_grad()                 # 清空判别器梯度
        
        # 用真实图像训练 D:希望 D(real_imgs) → 1
        label = torch.full((batch_size,), real_label, device=device)  # 创建全 1 标签
        output = D(real_imgs).view(-1)          # 判别器输出,展平为 1D
        loss_D_real = criterion(output, label)  # 计算真实样本损失
        
        # 生成假图像
        noise = torch.randn(batch_size, 100, device=device)  # 从标准正态分布采样噪声
        fake_imgs = G(noise)                    # 生成器生成假图像
        
        # 用假图像训练 D:希望 D(fake_imgs) → 0
        label.fill_(fake_label)                 # 将标签填充为 0
        # 注意:使用 .detach() 阻断梯度回传到生成器 G,仅更新 D
        output = D(fake_imgs.detach()).view(-1)
        loss_D_fake = criterion(output, label)  # 计算伪造样本损失
        
        # 判别器总损失 = 真实损失 + 伪造损失
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()                       # 反向传播计算梯度
        optimizer_D.step()                      # 更新判别器参数

        # ---------------------
        # 步骤2: 训练生成器 G
        # 目标:最小化 log(1 - D(G(z))),等价于最大化 log(D(G(z)))
        # 即:让 D 认为假图是真实的
        # ---------------------
        optimizer_G.zero_grad()                 # 清空生成器梯度
        
        # 重新设置标签为 1(欺骗 D)
        label.fill_(real_label)
        # 注意:此处不使用 detach(),梯度需回传至 G
        output = D(fake_imgs).view(-1)          # D 对假图的判别结果
        loss_G = criterion(output, label)       # G 的损失:希望 D 输出接近 1
        loss_G.backward()                       # 反向传播(梯度流经 D 和 G)
        optimizer_G.step()                      # 更新生成器参数

代码亮点

  • 严格遵循 DCGAN 论文推荐架构(转置卷积、BatchNorm、LeakyReLU/Tanh)
  • 交替训练策略:每步先更新 D,再更新 G,防止一方过强导致训练崩溃
  • 梯度控制:在训练 D 时对 fake_imgs 使用 .detach(),避免不必要的梯度计算
  • 标签平滑(可选改进):可将 real_label 设为 0.9 而非 1.0,提升鲁棒性

5. GAN 的挑战与改进

问题 表现 解决方案
模式崩溃(Mode Collapse) G 只生成少数几种样本 WGAN(Wasserstein GAN)、Unrolled GAN
训练不稳定 D 过强导致 G 梯度消失 特征匹配、谱归一化(Spectral Norm)
评估困难 无明确损失衡量生成质量 Inception Score (IS)、FID 分数

📌 WGAN 核心改进:将原始 JS 散度替换为 Wasserstein 距离,损失函数变为:
min ⁡ G max ⁡ ∥ f ∥ L ≤ 1 E x ∼ p data [ f ( x ) ] − E z ∼ p z [ f ( G ( z ) ) ] \min_G \max_{\|f\|_L \leq 1} \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{z \sim p_z}[f(G(z))] GminfL1maxExpdata[f(x)]Ezpz[f(G(z))]
其中 f f f 为 1-Lipschitz 函数,通过权重裁剪梯度惩罚实现。


6. 应用场景:GAN 能做什么?

  • 图像生成:StyleGAN 生成超写实人脸
  • 图像翻译:CycleGAN 实现马↔斑马、照片↔油画
  • 数据增强:为医疗影像生成稀缺病灶样本
  • 超分辨率:SRGAN 将低清图变高清

7. 结语与预告

GAN 开启了“无监督生成”的新纪元,其对抗思想也深刻影响了强化学习、自监督学习等领域。理解 GAN,是掌握现代生成式 AI 的关键一步。

🔗 延伸阅读

  • 上一篇:【[人工智能][深度学习] ① RNN核心算法介绍:从循环结构到LSTM门控机制](https://blog.csdn.net/xiezhiyi007/article/details/155281772)**
  • 再上一篇:【[人工智能]发展历程全景解析:从图灵测试到大模型时代(含CNN、Q-Learning深度实践)](https://blog.csdn.net/xiezhiyi007/article/details/155131043)

💡 下一篇预告
【人工智能】【深度学习】 ③ Transformer核心算法介绍:自注意力机制如何颠覆AI?
将深入剖析 QKV 计算、多头注意力、位置编码,并对比 RNN/CNN 的优劣,敬请期待!

Logo

更多推荐