Hands-On DDPM with PyTorch: Working with Diffusion Models Without Digging into the Math

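If you have seen AI-generated artwork on social media, you may have wondered how it is made. Diffusion models are among the most popular generative AI techniques today, and they are rapidly changing how content is created. This article skips the heavy mathematical derivations and goes straight to code: we build a complete DDPM (Denoising Diffusion Probabilistic Models) implementation from scratch in PyTorch.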

1. Environment Setup and Data Loading

Before starting, set up the development environment. Python 3.8+ and PyTorch 1.12+ are recommended:

pip install torch torchvision matplotlib tqdm

For the dataset we use the classic CIFAR-10, which contains 60,000 32x32 color images (the training split loaded below holds 50,000 of them):

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True, 
    transform=transform
)

dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=128, 
    shuffle=True
)

Tip: if your GPU has limited memory, reduce batch_size to 64 or 32.
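To see what the Normalize step in the transform does, the sketch below applies the same arithmetic, (x - 0.5) / 0.5, to a stand-in batch of random values rather than real CIFAR-10 images, so nothing is downloaded. The tensor `fake_batch` and the helper `denormalize` are illustrative names and are not part of the original code.

```python
import torch

# Stand-in for a batch produced by ToTensor(): values in [0, 1], shape (B, C, H, W)
torch.manual_seed(0)
fake_batch = torch.rand(128, 3, 32, 32)

# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) computes (x - mean) / std per channel
mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
normalized = (fake_batch - mean) / std

def denormalize(x):
    """Hypothetical helper: invert the normalization, mapping [-1, 1] back to [0, 1]."""
    return x * std + mean

restored = denormalize(normalized)
print(normalized.min().item(), normalized.max().item())  # both lie within [-1, 1]
```

The [-1, 1] range matters because the forward process mixes the image with zero-mean Gaussian noise, and the sampling code later maps the result back to [0, 1] with the inverse operation.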

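2. Implementing the Core DDPM Components

2.1 Noise Scheduler

The core of a diffusion model lies in how noise is added and removed. We define a noise scheduler that controls the noise strength at each timestep: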

import math

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

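timesteps = 1000
betas = cosine_beta_schedule(timesteps)  # the cosine schedule usually gives better results

# Precompute values used by the forward process and by sampling
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Cumulative product shifted by one step; it equals 1 before the first step
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
sqrt_recip_alphas = torch.sqrt(1. / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance of the posterior q(x_{t-1} | x_t, x_0), read by p_sample during sampling
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)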
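As a rough illustration of how the two schedules differ, the sketch below re-declares both schedule functions so it runs on its own and prints the cumulative product of alphas at a few timesteps. That product is the fraction of the original signal variance still present at step t. The chosen timesteps and the names `linear_bar` and `cosine_bar` are illustrative.

```python
import math
import torch

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    x = torch.linspace(0, timesteps, timesteps + 1)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

timesteps = 1000
linear_bar = torch.cumprod(1.0 - linear_beta_schedule(timesteps), dim=0)
cosine_bar = torch.cumprod(1.0 - cosine_beta_schedule(timesteps), dim=0)

# alphas_cumprod[t] is the share of the original signal variance left at step t
for t in (0, 249, 499, 749, 999):
    print(f"t={t:4d}  linear={linear_bar[t].item():.4f}  cosine={cosine_bar[t].item():.4f}")
```

With these settings the linear schedule destroys most of the signal well before the midpoint, while the cosine schedule removes it more gradually, which is one common explanation for why it tends to work better on small images.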

2.2 Forward Noising Process

The forward process gradually turns the data into Gaussian noise. It is fixed and requires no training:

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
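The following sketch exercises the forward process on stand-in data to show what `extract` and `q_sample` do at an early, a middle, and the final timestep. It re-declares both functions so it runs on its own, and it uses the linear schedule and random tensors in place of CIFAR-10 images purely for illustration.

```python
import torch

timesteps = 1000
betas = torch.linspace(0.0001, 0.02, timesteps)  # linear schedule, for illustration
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

def extract(a, t, x_shape):
    # Pick a[t] for each batch element and reshape to (B, 1, 1, 1) for broadcasting
    out = a.gather(-1, t.cpu())
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1))).to(t.device)

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

torch.manual_seed(0)
x_start = torch.rand(3, 3, 32, 32) * 2 - 1   # stand-in images in [-1, 1]
t = torch.tensor([0, 499, 999])              # early, middle, and final timestep
noise = torch.randn_like(x_start)
x_noisy = q_sample(x_start, t, noise)

# Per-sample coefficients: the signal weight shrinks and the noise weight grows with t
signal_weight = extract(sqrt_alphas_cumprod, t, x_start.shape).flatten()
noise_weight = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape).flatten()
print(signal_weight, noise_weight)
```

Because the squared weights sum to one at every step, the noisy image keeps roughly unit variance throughout the process, and any timestep can be sampled directly without iterating through the earlier ones.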

3. Building the U-Net Model

The U-Net is the core network that DDPM uses to predict noise. Below we implement a simplified version:

import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # One GroupNorm per convolution so the two layers do not share affine parameters
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        
    def forward(self, x, t):
        h = self.conv1(x)
        h = self.norm1(h)
        h = F.silu(h)
        
        # Project the time embedding to out_ch channels and broadcast it over H and W
        time_emb = F.silu(self.time_mlp(t))
        h = h + time_emb[:, :, None, None]
        
        h = self.conv2(h)
        h = self.norm2(h)
        h = F.silu(h)
        return h

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(100),
            nn.Linear(100, 256),
            nn.SiLU(),
            nn.Linear(256, 256)
        )
        
        self.down1 = Block(3, 64, 256)
        self.down2 = Block(64, 128, 256)
        self.down3 = Block(128, 256, 256)
        
        self.mid = Block(256, 256, 256)
        
        self.up1 = Block(512, 128, 256)
        self.up2 = Block(256, 64, 256)
        self.up3 = Block(128, 64, 256)
        
        self.out = nn.Conv2d(64, 3, 1)
        
    def forward(self, x, t):
        t = self.time_mlp(t)
        
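        # Downsampling path (for a 32x32 input: 32x32 -> 16x16 -> 8x8)
        h1 = self.down1(x, t)
        h2 = self.down2(F.max_pool2d(h1, 2), t)
        h3 = self.down3(F.max_pool2d(h2, 2), t)
        
        # Middle block (4x4 for a 32x32 input)
        h = self.mid(F.max_pool2d(h3, 2), t)
        
        # Upsampling path: upsample, then concatenate the matching skip connection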
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.up1(torch.cat([h, h3], dim=1), t)
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.up2(torch.cat([h, h2], dim=1), t)
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.up3(torch.cat([h, h1], dim=1), t)
        
        return self.out(h)

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
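To make the time embedding more concrete, the sketch below rewrites the body of `SinusoidalPositionEmbeddings.forward` as a standalone function (the name `sinusoidal_embedding` is illustrative) and inspects its output for a few timesteps. In the U-Net above, this 100-dimensional vector is then passed through two linear layers to obtain the 256-dimensional embedding consumed by every Block.

```python
import math
import torch

def sinusoidal_embedding(time, dim):
    # Same computation as SinusoidalPositionEmbeddings.forward, written as a function
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)
    args = time[:, None] * freqs[None, :]
    return torch.cat((args.sin(), args.cos()), dim=-1)

t = torch.tensor([0, 1, 500, 999])
emb = sinusoidal_embedding(t, 100)
print(emb.shape)                    # one 100-dimensional vector per timestep
print(emb[0, :3], emb[0, 50:53])    # at t = 0 the sine half is 0 and the cosine half is 1
```

Each timestep gets a distinct pattern of sines and cosines at geometrically spaced frequencies, which gives the network a smooth, bounded representation of t instead of a raw integer.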

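4. Training and Sampling

4.1 Training Loop

The DDPM training objective is for the network to learn to predict the noise that was added to the image: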

def train(model, dataloader, optimizer, epochs, device):
    model.train()
    for epoch in range(epochs):
        for step, (images, _) in enumerate(dataloader):
            images = images.to(device)
            
            # Sample a random timestep for each image in the batch
            t = torch.randint(0, timesteps, (images.shape[0],), device=device).long()
            
            # Draw the Gaussian noise to add
            noise = torch.randn_like(images)
            
            # Forward noising process
            noisy_images = q_sample(images, t, noise)
            
            # Predict the noise
            predicted_noise = model(noisy_images, t)
            
            # Compute the loss
            loss = F.mse_loss(predicted_noise, noise)
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if step % 100 == 0:
                print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")
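The sketch below runs a single optimization step of the same noise-prediction objective on random stand-in data, so it can be executed in seconds on a CPU. `TinyDenoiser` is a hypothetical placeholder with the same (x, t) interface as the U-Net, not the model from this article. The Adam optimizer and the learning rate of 2e-4 are assumptions as well, since the article does not specify them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for the U-Net: same (x, t) -> noise interface, far smaller."""
    def __init__(self, timesteps):
        super().__init__()
        self.timesteps = timesteps
        self.conv_in = nn.Conv2d(4, 16, 3, padding=1)   # 3 image channels + 1 time channel
        self.conv_out = nn.Conv2d(16, 3, 3, padding=1)

    def forward(self, x, t):
        # Broadcast the normalized timestep as an extra input channel
        t_map = (t.float() / self.timesteps).view(-1, 1, 1, 1).expand(-1, 1, *x.shape[2:])
        return self.conv_out(F.silu(self.conv_in(torch.cat([x, t_map], dim=1))))

torch.manual_seed(0)
timesteps = 1000
betas = torch.linspace(0.0001, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

model = TinyDenoiser(timesteps)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)  # assumed optimizer settings

images = torch.rand(8, 3, 32, 32) * 2 - 1                  # random stand-in batch in [-1, 1]
t = torch.randint(0, timesteps, (images.shape[0],))
noise = torch.randn_like(images)

# Closed-form forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
abar_t = alphas_cumprod[t].view(-1, 1, 1, 1)
noisy_images = abar_t.sqrt() * images + (1.0 - abar_t).sqrt() * noise

loss = F.mse_loss(model(noisy_images, t), noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss after one step: {loss.item():.4f}")
```

With the real model, the same pattern would apply: create the U-Net, move it to the device, build an optimizer over its parameters, and pass all of them to `train` together with the dataloader.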

4.2 Sampling

After training, the model can generate new images by starting from pure noise and denoising step by step:

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Compute the model mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    device = next(model.parameters()).device
    # Start from pure Gaussian noise
    img = torch.randn((batch_size, channels, image_size, image_size), device=device)
    
    for i in reversed(range(0, timesteps)):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, i)
    
    # Map the image from [-1, 1] to [0, 1] and clamp any out-of-range values
    img = ((img + 1) * 0.5).clamp(0, 1)
    return img
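The function `p_sample` reads two schedule-dependent tensors, `sqrt_recip_alphas` and `posterior_variance`. The sketch below shows one way to derive them from `betas`, following the standard DDPM posterior formulas, and checks two properties that the sampler relies on. It uses the linear schedule for illustration only.

```python
import torch

timesteps = 1000
betas = torch.linspace(0.0001, 0.02, timesteps)  # linear schedule, for illustration
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Shift by one step; by convention the cumulative product before step 0 is 1
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

# Coefficient that rescales (x_t minus the scaled predicted noise) in the model mean
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# Variance of q(x_{t-1} | x_t, x_0); p_sample adds noise with this variance for t > 0
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

print(posterior_variance[:3], posterior_variance[-3:])
```

The variance at t = 0 is exactly zero, which is consistent with `p_sample` returning the mean without added noise at the last denoising step.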

5. Model Optimization and Tips

In practice, the following strategies can improve DDPM performance:

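  1. Learning-rate scheduling: cosine annealing of the learning rate can help the model converge faster
  2. Mixed-precision training: training in FP16 reduces memory use and speeds up training
  3. EMA model: using an exponential moving average of the model parameters can improve sample quality
  4. Progressive training: start at a low resolution and increase it gradually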
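As a minimal sketch of the first tip, the example below attaches PyTorch's `CosineAnnealingLR` scheduler to an optimizer and records the learning rate after every step. The linear placeholder model, the starting rate of 2e-4, and the run length of 100 steps are assumptions chosen only to keep the example short.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)          # placeholder model
base_lr = 2e-4                   # assumed starting learning rate
total_steps = 100                # assumed length of the training run
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

lrs = []
for step in range(total_steps):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy loss so gradients exist
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()             # decay the learning rate once per optimizer step
    lrs.append(scheduler.get_last_lr()[0])

print(lrs[0], lrs[total_steps // 2 - 1], lrs[-1])
```

In the training loop of section 4.1, the `scheduler.step()` call would go right after `optimizer.step()`, with `T_max` set to the total number of optimizer steps if the schedule is stepped per batch.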
# Example: EMA model implementation
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0
        
    def update_model_average(self, ema_model, current_model):
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            old_weight, new_weight = ema_params.data, current_params.data
            ema_params.data = self.update_average(old_weight, new_weight)
            
    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
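The sketch below shows how an EMA copy is typically maintained alongside the trained model. It applies the same update rule as `update_average` but as a standalone function (`ema_update` is an illustrative name) so that it runs on its own. The linear placeholder model and beta = 0.9 are chosen only to make the effect visible. Values much closer to 1, such as 0.999, are more common in practice.

```python
import copy
import torch
import torch.nn as nn

def ema_update(ema_model, model, beta):
    # Same rule as EMA.update_average: new_ema = beta * old_ema + (1 - beta) * current
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(beta).add_(p, alpha=1.0 - beta)

torch.manual_seed(0)
model = nn.Linear(2, 2)                   # placeholder for the U-Net
ema_model = copy.deepcopy(model)          # the EMA copy starts from the same weights
ema_model.requires_grad_(False)           # the EMA copy is never trained directly

start = copy.deepcopy(model.state_dict())
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)                       # pretend an optimizer step moved every weight by +1

ema_update(ema_model, model, beta=0.9)
print(ema_model.weight - start["weight"])  # each entry moved by about 0.1, not 1.0
```

During training, the update would be called once after every optimizer step, and sampling would use the EMA copy instead of the raw model. Note that this sketch only averages parameters, so any buffers a model may hold are left untouched.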

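After roughly 50 epochs of training on CIFAR-10, the model should begin to generate recognizable objects. Although 32x32 is a low resolution, this complete implementation already contains the key components of DDPM.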
