用PyTorch实战DDPM:零数学基础也能玩转扩散模型

在咖啡馆里,我遇到一位刚入行AI的开发者小张。他盯着Stable Diffusion生成的图片发呆,却对背后的扩散模型原理望而却步:"那些数学公式看着就头疼,难道不精通概率论就玩不转生成式AI吗?"这让我意识到,大多数教程都把扩散模型讲成了数学考试,而忽略了它本质上是一个可以通过代码直观理解的算法框架。本文将用PyTorch带你从零实现DDPM(Denoising Diffusion Probabilistic Models),全程只需基础Python知识,我们会把复杂理论转化为可运行的代码块,让你在动手实践中建立直觉认知。

1. 环境准备与数据加载

1.1 安装依赖库

确保你的Python环境≥3.8,然后安装以下核心库:

pip install torch torchvision matplotlib tqdm

1.2 选择训练数据集

我们将使用MNIST作为示例数据集,它的低分辨率特性适合快速验证模型:

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

提示:如果想尝试人脸生成,可替换为CelebA数据集,但需要调整后续的模型容量和训练时长

2. DDPM核心组件实现

2.1 噪声调度器

这是控制加噪过程的关键组件,我们采用余弦调度方案:

import math

def cosine_beta_schedule(timesteps, s=0.008):
    """
    余弦噪声调度器
    Args:
        timesteps: 总时间步数
        s: 控制起始噪声率的偏移量
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

timesteps = 200
betas = cosine_beta_schedule(timesteps)

2.2 前向加噪过程

这是扩散模型区别于其他生成模型的关键步骤:

def q_sample(x_start, t, noise=None):
    """
    对输入图像逐步加噪
    Args:
        x_start: 原始图像 (B, C, H, W)
        t: 时间步 (B,)
        noise: 可选的外部噪声输入
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod[t])[:, None, None, None]
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod[t])[:, None, None, None]
    
    return sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise

可视化加噪过程的效果:

时间步 图像示例 噪声比例
t=0 ![原始图像] 0%
t=50 ![轻度加噪] 30%
t=100 ![中度加噪] 60%
t=200 ![完全噪声] 100%

3. 构建U-Net噪声预测器

3.1 基础残差块

这是U-Net的核心构建模块:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.block = nn.Sequential(
            nn.GroupNorm(32, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )
        self.res_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t):
        h = self.block(x)
        t_emb = self.time_mlp(t)[:, :, None, None]
        return h + t_emb + self.res_conv(x)

3.2 完整U-Net架构

实现一个简化版的DDPM U-Net:

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, dim=32):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
        self.down1 = ResidualBlock(in_channels, dim, dim)
        self.down2 = ResidualBlock(dim, dim*2, dim)
        self.mid = ResidualBlock(dim*2, dim*2, dim)
        self.up1 = ResidualBlock(dim*3, dim, dim)
        self.up2 = ResidualBlock(dim*2, out_channels, dim)
        self.conv_out = nn.Conv2d(out_channels, out_channels, 1)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)
        # 下采样路径
        h1 = self.down1(x, t_emb)
        h2 = self.down2(F.max_pool2d(h1, 2), t_emb)
        # 中间层
        h_mid = self.mid(F.max_pool2d(h2, 2), t_emb)
        # 上采样路径
        h_up1 = self.up1(F.interpolate(h_mid, scale_factor=2), t_emb)
        h_up2 = self.up2(F.interpolate(torch.cat([h_up1, h2], dim=1), scale_factor=2), t_emb)
        return self.conv_out(torch.cat([h_up2, h1], dim=1))

4. 训练与采样流程

4.1 训练循环实现

关键训练步骤分解:

  1. 随机采样时间步:均匀选择加噪强度
  2. 生成带噪图像:按选定强度加噪
  3. 预测噪声:U-Net尝试还原添加的噪声
  4. 计算损失:比较预测噪声与真实噪声
def train_step(model, x_start, optimizer):
    model.train()
    optimizer.zero_grad()
    
    # 随机采样时间步
    t = torch.randint(0, timesteps, (x_start.shape[0],), device=device)
    
    # 生成带噪图像和随机噪声
    noise = torch.randn_like(x_start)
    x_noisy = q_sample(x_start, t, noise)
    
    # 预测噪声并计算损失
    predicted_noise = model(x_noisy, t)
    loss = F.mse_loss(noise, predicted_noise)
    
    loss.backward()
    optimizer.step()
    return loss.item()

4.2 图像生成过程

反向去噪的典型流程:

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # 计算预测均值
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

5. 实战技巧与性能优化

5.1 加速采样的关键方法

  • 时间步压缩:将200步压缩到50步
  • 混合精度训练:使用torch.cuda.amp
  • 缓存计算结果:预先计算调度参数
# 示例:时间步重参数化
def rescale_timesteps(t, new_timesteps):
    return (t.float() * (new_timesteps - 1) / timesteps).long()

5.2 常见问题排查表

问题现象 可能原因 解决方案
生成图像模糊 模型容量不足 增加U-Net通道数
训练损失不下降 学习率不当 尝试1e-4到1e-5范围
生成图像有网格伪影 反卷积操作导致 替换为插值+卷积

在Colab上实测,使用单个T4 GPU训练MNIST约30分钟即可看到初步效果。记得保存中间检查点,观察不同训练阶段的生成质量变化。

Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐