用Python代码拆解Diffusion Model:从加噪到去噪的视觉化之旅

Diffusion Model近年来在图像生成领域掀起了一场革命,但许多初学者在面对那些复杂的数学公式时常常感到无从下手。本文将通过Python代码,带你一步步拆解Diffusion Model的核心机制——前向加噪与反向去噪过程。我们将用PyTorch实现每个关键步骤,并通过可视化让你直观感受图像是如何一步步被噪声"腐蚀",又如何神奇地被"复原"的。

1. 理解Diffusion Model的基本原理

Diffusion Model的核心思想其实非常直观:它模拟了一个渐进式的"破坏"与"重建"过程。想象你有一幅清晰的画作,前向过程就像有人不断向画上撒沙子,直到画作完全被沙子覆盖;而反向过程则是一个聪明的修复师,他能根据沙子的分布规律,一步步还原出原始画作。

这个模型的两个关键阶段:

  • 前向过程(Forward Process):逐步向数据添加高斯噪声,相当于将数据"扩散"到一个纯噪声分布
  • 反向过程(Reverse Process):学习如何逐步去除噪声,从纯噪声中重建原始数据

在代码实现前,我们需要设置基本环境:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import numpy as np

# 设置随机种子保证可重复性
torch.manual_seed(42)

2. 前向加噪过程的代码实现

前向过程的核心公式是:

Xt = √(αt)*Xt-1 + √(1-αt)*εt

让我们用代码将这个公式具象化。首先需要定义噪声调度(noise schedule),它决定了每个时间步添加多少噪声。

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """线性噪声调度,返回所有时间步的β值"""
    return torch.linspace(start, end, timesteps)

timesteps = 200
betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)  # α的累积乘积

接下来实现前向扩散的核心函数:

def forward_diffusion_sample(x0, t, device="cpu"):
    """根据给定时间步t对输入x0添加噪声"""
    noise = torch.randn_like(x0)
    sqrt_alphas_cumprod_t = torch.sqrt(alphas_cumprod[t])[:, None, None, None]
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - alphas_cumprod[t])[:, None, None, None]
    
    return sqrt_alphas_cumprod_t.to(device) * x0.to(device) + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)

让我们用一张示例图片演示这个过程:

# 加载示例图片
image = Image.open("example.jpg")
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
x0 = transform(image).unsqueeze(0)

# 可视化不同时间步的加噪效果
plt.figure(figsize=(15, 5))
for i, t in enumerate([0, 50, 100, 150, 199]):
    xt, _ = forward_diffusion_sample(x0, torch.tensor([t]))
    plt.subplot(1, 5, i+1)
    plt.imshow(xt.squeeze().permute(1, 2, 0).numpy())
    plt.title(f"t={t}")
    plt.axis('off')
plt.show()

3. 构建简化的U-Net噪声预测器

反向过程的核心是一个能够预测噪声的神经网络。我们实现一个简化版的U-Net:

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.relu = nn.ReLU()
        
    def forward(self, x, t):
        h = self.relu(self.conv1(x))
        time_emb = self.relu(self.time_mlp(t))
        h = h + time_emb[:, :, None, None]
        return self.relu(self.conv2(h))

class SimpleUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )
        
        self.down1 = Block(3, 64, 32)
        self.down2 = Block(64, 128, 32)
        self.up1 = Block(128, 64, 32)
        self.up2 = Block(64, 3, 32)
        
    def forward(self, x, t):
        t = self.time_mlp(t)
        x1 = self.down1(x, t)
        x2 = self.down2(x1, t)
        x = self.up1(x2, t)
        return self.up2(x, t)

4. 训练噪声预测模型

有了U-Net结构,我们需要定义损失函数和训练过程:

def get_loss(model, x0, t):
    """计算噪声预测的损失"""
    xt, noise = forward_diffusion_sample(x0, t)
    predicted_noise = model(xt, t)
    return torch.nn.functional.mse_loss(noise, predicted_noise)

def train(model, dataloader, epochs=100, device="cpu"):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.to(device)
    
    for epoch in range(epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            t = torch.randint(0, timesteps, (batch.size(0),), device=device).float()
            loss = get_loss(model, batch, t)
            loss.backward()
            optimizer.step()
        
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

5. 实现反向去噪过程

训练好噪声预测器后,我们可以实现反向去噪过程:

@torch.no_grad()
def reverse_process(model, shape, device="cpu"):
    """从纯噪声开始逐步去噪"""
    x = torch.randn(shape, device=device)
    
    for t in range(timesteps-1, -1, -1):
        t_tensor = torch.tensor([t], device=device).float()
        predicted_noise = model(x, t_tensor)
        alpha_t = alphas[t]
        alpha_t_cumprod = alphas_cumprod[t]
        
        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)
            
        x = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_t_cumprod)) * predicted_noise) + torch.sqrt(1 - alpha_t) * noise
    
    return x

6. 完整流程演示与可视化

现在让我们把整个过程串起来,从加噪到去噪:

# 初始化模型和优化器
model = SimpleUNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 模拟训练过程(实际使用时需要真实数据集)
for epoch in range(10):
    t = torch.randint(0, timesteps, (1,))
    loss = get_loss(model, x0, t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# 生成新图像
generated = reverse_process(model, x0.shape)
plt.imshow(generated.squeeze().permute(1, 2, 0).detach().numpy())
plt.axis('off')
plt.show()

7. 关键参数的影响与调优

Diffusion Model的性能很大程度上依赖于以下几个关键参数的选择:

参数 典型值 影响 调优建议
时间步数 200-1000 步数越多,生成质量越高但速度越慢 根据硬件条件平衡质量与速度
β起始值 0.0001 控制初始噪声强度 太小会导致前几步变化不明显
β结束值 0.02 控制最终噪声强度 太大会导致图像过早被破坏
学习率 1e-3到1e-4 影响训练稳定性 使用学习率调度器

在实际项目中,你可以尝试不同的噪声调度策略:

def cosine_beta_schedule(timesteps, s=0.008):
    """余弦噪声调度,通常能获得更好的结果"""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

8. 实际应用中的技巧与挑战

在实现Diffusion Model时,有几个常见的陷阱需要注意:

  • 梯度不稳定:当时间步数很大时,不同时间步的梯度规模可能差异很大。解决方案包括:

    • 使用梯度裁剪
    • 对不同时间步的损失进行加权
    • 采用渐进式训练策略
  • 模式坍塌:模型可能只学会生成有限的几种样本。可以通过以下方法缓解:

    • 增加模型容量
    • 使用更复杂的架构(如带Attention的U-Net)
    • 调整噪声调度策略
  • 计算资源消耗:Diffusion Model训练通常需要大量计算资源。可以考虑:

    • 使用混合精度训练
    • 实现分布式训练
    • 采用知识蒸馏技术训练更小的模型

一个实用的训练技巧是在训练初期使用较小的图像尺寸,后期再切换到全尺寸:

# 渐进式训练示例
for phase in [(64, 100), (128, 200), (256, 300)]:
    size, epochs = phase
    # 调整数据加载器使用新尺寸
    # 调整模型架构(如果需要)
    for epoch in range(epochs):
        # 训练逻辑

通过本文的代码实现,你应该已经对Diffusion Model的核心机制有了直观理解。记住,真正掌握这些概念的最佳方式是在你自己的项目中实践这些代码,并尝试调整不同的参数和架构。

Logo

免费领 50 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐