用PyTorch复现CycleGAN:从零开始手搓一个图像风格转换模型(附完整代码)
本文详细介绍了如何使用PyTorch从零开始实现CycleGAN模型,完成图像风格转换任务。通过环境准备、核心架构实现、损失函数设计到训练策略的完整流程,帮助开发者掌握无监督图像转换技术。文章包含完整代码示例,特别适合想深入生成对抗网络实践的PyTorch开发者。
用PyTorch构建CycleGAN:从零实现图像风格转换的工程实践
在计算机视觉领域,图像到图像的转换一直是个令人着迷的课题。想象一下,将夏日的风景照瞬间变成冬日的雪景,或是将素描草图自动转化为逼真的彩色图像——这正是CycleGAN展现魔力的地方。不同于普通的GAN,CycleGAN不需要成对的训练数据,这种无监督学习的能力让它成为解决现实问题的利器。本文将带你从PyTorch的基础张量操作开始,逐步构建完整的CycleGAN框架,特别适合那些已经熟悉PyTorch但想深入生成对抗网络实践的开发者。
1. 环境准备与数据加载
1.1 搭建PyTorch开发环境
推荐使用Python 3.8+和PyTorch 1.10+的组合,这是经过验证的稳定版本搭配。通过conda可以快速创建隔离环境:
conda create -n cyclegan python=3.8
conda activate cyclegan
pip install torch torchvision torchaudio
pip install opencv-python matplotlib tqdm
对于GPU加速,确保安装对应CUDA版本的PyTorch。可以通过以下代码验证环境:
import torch
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"设备数量: {torch.cuda.device_count()}")
1.2 准备图像数据集
CycleGAN的魅力在于它只需要两个域的独立图像集合。以经典的马↔斑马转换为例,我们需要:
-
创建项目目录结构:
cyclegan_project/ ├── datasets/ │ ├── horse2zebra/ │ │ ├── trainA/ # 马训练集 │ │ ├── trainB/ # 斑马训练集 │ │ └── testA/ # 马测试集 │ └── README.md ├── checkpoints/ ├── results/ └── src/ -
使用
torchvision.datasets.ImageFolder配合自定义transform:
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(286, transforms.InterpolationMode.BICUBIC),
transforms.RandomCrop(256),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
horse_dataset = ImageFolder('datasets/horse2zebra/trainA', transform=transform)
zebra_dataset = ImageFolder('datasets/horse2zebra/trainB', transform=transform)
提示:图像预处理中的随机裁剪和水平翻转是重要的数据增强手段,能有效防止过拟合。
2. CycleGAN核心架构实现
2.1 生成器网络设计
CycleGAN采用U-Net结构的生成器,包含编码器-解码器架构和跳跃连接。以下是关键实现细节:
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(in_channels, in_channels, 3),
nn.InstanceNorm2d(in_channels),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_channels, in_channels, 3),
nn.InstanceNorm2d(in_channels)
)
def forward(self, x):
return x + self.block(x)
class Generator(nn.Module):
def __init__(self, in_channels=3, out_channels=3, num_residual=9):
super().__init__()
# 编码器部分
self.encoder = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(in_channels, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 128, 3, stride=2, padding=1),
nn.InstanceNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, 3, stride=2, padding=1),
nn.InstanceNorm2d(256),
nn.ReLU(inplace=True)
)
# 残差块
self.residual = nn.Sequential(
*[ResidualBlock(256) for _ in range(num_residual)]
)
# 解码器部分
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(128),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(3),
nn.Conv2d(64, out_channels, 7),
nn.Tanh()
)
def forward(self, x):
x = self.encoder(x)
x = self.residual(x)
x = self.decoder(x)
return x
关键设计选择:
- 反射填充(ReflectionPad):比零填充更能保持图像边缘连续性
- 实例归一化(InstanceNorm):对风格转换任务特别有效
- 残差连接:帮助解决深层网络梯度消失问题
- Tanh激活:将输出限制在[-1,1]范围,对应归一化后的输入
2.2 判别器网络实现
判别器采用PatchGAN架构,对图像的局部区域进行真伪判断:
class Discriminator(nn.Module):
def __init__(self, in_channels=3):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, 4, stride=1, padding=1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, 4, stride=1, padding=1)
)
def forward(self, x):
return self.model(x)
PatchGAN的特点在于:
- 输出不是单一的真假判断,而是N×N的矩阵
- 每个元素对应输入图像的一个局部区域
- 能捕捉图像的局部细节特征
- 计算效率高,参数少
3. 损失函数与训练策略
3.1 复合损失函数设计
CycleGAN的核心创新在于循环一致性损失,完整的损失函数包含多个部分:
# 对抗损失
criterion_GAN = nn.MSELoss()
# 循环一致性损失
criterion_cycle = nn.L1Loss()
# 身份损失
criterion_identity = nn.L1Loss()
# 生成器G的完整损失
loss_G = (
criterion_GAN(D_B(fake_B), valid) + # GAN loss G
criterion_GAN(D_A(fake_A), valid) + # GAN loss F
criterion_cycle(rec_A, real_A) * lambda_cycle + # 循环一致性A→B→A
criterion_cycle(rec_B, real_B) * lambda_cycle + # 循环一致性B→A→B
criterion_identity(identity_A, real_A) * lambda_identity + # 身份损失A
criterion_identity(identity_B, real_B) * lambda_identity # 身份损失B
)
各损失项的平衡系数经验值:
- λ_cycle:10(循环一致性损失权重)
- λ_identity:0.5(身份损失权重)
3.2 优化器配置与学习率调整
使用Adam优化器并实现学习率衰减:
optimizer_G = torch.optim.Adam(
itertools.chain(G_AB.parameters(), G_BA.parameters()),
lr=lr, betas=(0.5, 0.999)
)
optimizer_D = torch.optim.Adam(
itertools.chain(D_A.parameters(), D_B.parameters()),
lr=lr, betas=(0.5, 0.999)
)
# 学习率线性衰减
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + 1 - 100) / float(101)
return lr_l
scheduler_G = lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_rule)
scheduler_D = lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lambda_rule)
训练过程中的关键技巧:
- 判别器比生成器多训练几次(如D_step=3)
- 使用历史生成的图像缓冲池(buffer_size=50)
- 逐步降低学习率稳定训练
4. 训练循环与结果可视化
4.1 完整训练流程实现
def train_one_epoch(G_AB, G_BA, D_A, D_B,
dataloader_A, dataloader_B,
optimizer_G, optimizer_D,
device, epoch, n_epochs):
for i, (real_A, real_B) in enumerate(zip(dataloader_A, dataloader_B)):
# 数据转移到设备
real_A = real_A[0].to(device)
real_B = real_B[0].to(device)
# 生成假图像
fake_B = G_AB(real_A)
fake_A = G_BA(real_B)
# 训练生成器
optimizer_G.zero_grad()
loss_G = compute_generator_loss(
G_AB, G_BA, D_A, D_B,
real_A, real_B, fake_A, fake_B
)
loss_G.backward()
optimizer_G.step()
# 训练判别器
if i % opt.D_step == 0:
optimizer_D.zero_grad()
loss_D = compute_discriminator_loss(
D_A, D_B, real_A, real_B,
fake_A.detach(), fake_B.detach()
)
loss_D.backward()
optimizer_D.step()
# 打印训练信息
if i % opt.print_freq == 0:
print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader_A)}] "
f"Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}")
4.2 结果可视化与模型保存
训练过程中定期保存样本图像和模型权重:
def save_sample_images(epoch, G_AB, G_BA, val_dataloader_A, device):
G_AB.eval()
G_BA.eval()
with torch.no_grad():
real_A = next(iter(val_dataloader_A))[0].to(device)
fake_B = G_AB(real_A)
rec_A = G_BA(fake_B)
# 将图像从[-1,1]转换到[0,1]
real_A = 0.5 * (real_A + 1)
fake_B = 0.5 * (fake_B + 1)
rec_A = 0.5 * (rec_A + 1)
# 拼接对比图像
comparison = torch.cat([real_A, fake_B, rec_A], dim=3)
save_image(comparison, f"results/cyclegan_{epoch}.png", nrow=1)
G_AB.train()
G_BA.train()
# 保存模型检查点
def save_checkpoint(epoch, G_AB, G_BA, D_A, D_B):
torch.save({
'epoch': epoch,
'G_AB_state_dict': G_AB.state_dict(),
'G_BA_state_dict': G_BA.state_dict(),
'D_A_state_dict': D_A.state_dict(),
'D_B_state_dict': D_B.state_dict(),
}, f"checkpoints/cyclegan_{epoch}.pth")
5. 高级技巧与性能优化
5.1 训练稳定性提升策略
GAN训练 notoriously tricky,以下技巧能显著提升稳定性:
-
梯度惩罚:在判别器损失中加入梯度惩罚项
def compute_gradient_penalty(D, real_samples, fake_samples): alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device) interpolates = (alpha * real_samples + (1-alpha) * fake_samples).requires_grad_(True) d_interpolates = D(interpolates) gradients = torch.autograd.grad( outputs=d_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(d_interpolates), create_graph=True, retain_graph=True, only_inputs=True )[0] gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() return gradient_penalty -
谱归一化:稳定判别器训练
from torch.nn.utils import spectral_norm class DiscriminatorWithSN(nn.Module): def __init__(self): super().__init__() self.model = nn.Sequential( spectral_norm(nn.Conv2d(3, 64, 4, stride=2, padding=1)), nn.LeakyReLU(0.2), spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)), nn.LeakyReLU(0.2), spectral_norm(nn.Conv2d(128, 256, 4, stride=2, padding=1)), nn.LeakyReLU(0.2), spectral_norm(nn.Conv2d(256, 1, 4, stride=1, padding=1)) )
5.2 多GPU训练加速
对于大规模数据集,使用DataParallel加速:
if torch.cuda.device_count() > 1:
print(f"使用 {torch.cuda.device_count()} 个GPU")
G_AB = nn.DataParallel(G_AB)
G_BA = nn.DataParallel(G_BA)
D_A = nn.DataParallel(D_A)
D_B = nn.DataParallel(D_B)
G_AB.to(device)
G_BA.to(device)
D_A.to(device)
D_B.to(device)
5.3 混合精度训练
利用Apex或PyTorch原生AMP减少显存占用:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
fake_B = G_AB(real_A)
loss_G = compute_generator_loss(...)
scaler.scale(loss_G).backward()
scaler.step(optimizer_G)
scaler.update()
6. 模型部署与应用扩展
6.1 导出为生产环境可用模型
将训练好的模型导出为TorchScript:
# 跟踪模型
example_input = torch.rand(1, 3, 256, 256).to(device)
traced_G = torch.jit.trace(G_AB, example_input)
# 保存
traced_G.save("horse2zebra.pt")
# 加载使用
model = torch.jit.load("horse2zebra.pt")
fake_B = model(real_A)
6.2 扩展到其他图像转换任务
只需更换数据集,同样的架构可用于:
- 照片↔油画风格转换
- 白天↔夜晚场景转换
- 季节变换(夏↔冬)
- 医学图像模态转换(CT↔MRI)
关键调整点:
- 根据图像复杂度调整生成器残差块数量
- 对于高分辨率图像,增加PatchGAN的感受野
- 调整损失函数权重平衡
6.3 网页应用集成示例
使用Flask创建简单的Web API:
from flask import Flask, request, jsonify
import torchvision.transforms as transforms
from PIL import Image
import io
app = Flask(__name__)
model = torch.jit.load("horse2zebra.pt").eval()
@app.route('/transform', methods=['POST'])
def transform():
file = request.files['image']
img = Image.open(io.BytesIO(file.read()))
transform = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
img_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
output = model(img_tensor)
output_img = transforms.ToPILImage()(output.squeeze().cpu() * 0.5 + 0.5)
byte_arr = io.BytesIO()
output_img.save(byte_arr, format='PNG')
return jsonify({'image': byte_arr.getvalue().hex()})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
更多推荐




所有评论(0)