别再死记硬背UNet结构了!用PyTorch从零手搓一个,顺便搞懂跳跃连接到底在跳什么
本文通过PyTorch从零构建UNet模型,深入解析跳跃连接的工作原理及其在医学图像分割中的关键作用。通过代码实现和形状打印,直观展示数据在UNet中的流动过程,帮助开发者彻底理解这一经典网络结构的设计精髓。
从零构建UNet:用PyTorch拆解跳跃连接的秘密
当你第一次看到UNet的结构图时,那个对称的U型设计和那些横跨左右的连接线可能让你觉得既美观又神秘。但真正动手实现时,你是否遇到过这样的困惑:"这些跳跃连接到底在传输什么?为什么要在特定位置拼接特征图?"本文将带你用PyTorch从零开始构建一个完整的UNet,在每个关键节点打印张量形状,让你亲眼见证数据是如何在这个精巧的网络中流动的。
1. 理解UNet的核心设计
UNet最初是为医学图像分割设计的,它的成功很大程度上归功于其独特的架构。与传统的编码器-解码器结构不同,UNet引入了跳跃连接(Skip Connection),这使得网络能够同时利用低级的细节信息和高级的语义信息。
为什么跳跃连接如此重要?
- 信息互补 :下采样路径捕获的是抽象的语义信息("这是什么"),而上采样路径需要精确的定位信息("在哪里")
- 梯度流动 :跳跃连接为深层网络提供了更直接的梯度传播路径
- 数据效率 :在医学影像等小数据集场景下,这种设计能更充分地利用有限的数据
import torch
import torch.nn as nn
from torch.nn import functional as F
# 打印设备信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
2. 构建基础模块
在搭建完整的UNet之前,我们需要先创建几个基础构建块。这些模块就像乐高积木一样,将被组合成最终的模型。
2.1 双卷积块
每个UNet的层级都包含两个连续的3×3卷积,这是特征提取的基本单元:
class DoubleConv(nn.Module):
"""(卷积 => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
print(f"DoubleConv input shape: {x.shape}") # 调试输出
return self.double_conv(x)
2.2 下采样模块
下采样(也称为收缩路径)通过最大池化逐步降低空间分辨率,同时增加通道数:
class Down(nn.Module):
"""下采样模块:最大池化后接双卷积"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
print(f"Down input shape: {x.shape}") # 调试输出
return self.maxpool_conv(x)
2.3 上采样模块
上采样(扩展路径)通过转置卷积实现分辨率提升,并与跳跃连接的特征图拼接:
class Up(nn.Module):
"""上采样模块:转置卷积后与跳跃连接拼接,再接双卷积"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
print(f"Up input shapes - x1: {x1.shape}, x2: {x2.shape}") # 调试输出
x1 = self.up(x1)
# 计算填充以确保尺寸匹配
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
print(f"After concatenation shape: {x.shape}") # 调试输出
return self.conv(x)
3. 组装完整的UNet
现在我们可以将这些模块组合成完整的UNet结构。我们将按照经典的UNet架构,包含4个下采样和4个上采样阶段。
class UNet(nn.Module):
def __init__(self, n_channels=3, n_classes=1):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
# 下采样路径
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 1024)
# 上采样路径
self.up1 = Up(1024, 512)
self.up2 = Up(512, 256)
self.up3 = Up(256, 128)
self.up4 = Up(128, 64)
# 输出层
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
print(f"Initial input shape: {x.shape}\n")
# 下采样
x1 = self.inc(x)
print(f"After inc shape: {x1.shape}\n")
x2 = self.down1(x1)
print(f"After down1 shape: {x2.shape}\n")
x3 = self.down2(x2)
print(f"After down2 shape: {x3.shape}\n")
x4 = self.down3(x3)
print(f"After down3 shape: {x4.shape}\n")
x5 = self.down4(x4)
print(f"After down4 shape: {x5.shape}\n")
# 上采样
x = self.up1(x5, x4)
print(f"After up1 shape: {x.shape}\n")
x = self.up2(x, x3)
print(f"After up2 shape: {x.shape}\n")
x = self.up3(x, x2)
print(f"After up3 shape: {x.shape}\n")
x = self.up4(x, x1)
print(f"After up4 shape: {x.shape}\n")
# 输出
logits = self.outc(x)
print(f"Final output shape: {logits.shape}")
return logits
4. 数据流动可视化
让我们创建一个测试输入,观察数据在网络中的流动过程:
# 创建测试输入
test_input = torch.randn((1, 3, 572, 572)).to(device)
print("测试输入张量形状:", test_input.shape)
# 初始化模型
model = UNet().to(device)
# 前向传播
with torch.no_grad():
output = model(test_input)
运行这段代码,你将在控制台看到类似以下的输出:
测试输入张量形状: torch.Size([1, 3, 572, 572])
Initial input shape: torch.Size([1, 3, 572, 572])
DoubleConv input shape: torch.Size([1, 3, 572, 572])
After inc shape: torch.Size([1, 64, 572, 572])
Down input shape: torch.Size([1, 64, 572, 572])
DoubleConv input shape: torch.Size([1, 64, 286, 286])
After down1 shape: torch.Size([1, 128, 286, 286])
[...更多中间形状输出...]
Final output shape: torch.Size([1, 1, 388, 388])
5. 跳跃连接的内部机制
跳跃连接是UNet最精妙的设计,但也是最容易让人困惑的部分。让我们深入分析一个具体的上采样步骤:
# 假设我们正在处理up1阶段
x5_shape = torch.Size([1, 1024, 35, 35]) # 来自最深层的特征
x4_shape = torch.Size([1, 512, 35, 35]) # 对应的跳跃连接特征
# 上采样过程
up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
x5_up = up(x5) # 形状变为 [1, 512, 70, 70]
# 裁剪x4以匹配尺寸
x4_cropped = center_crop(x4, [70, 70])
# 拼接操作
x = torch.cat([x4_cropped, x5_up], dim=1) # 形状变为 [1, 1024, 70, 70]
为什么需要裁剪? 由于卷积的舍入误差,特征图尺寸可能会有微小差异。UNet原始论文中使用了镜像填充的卷积,确保所有中间特征图尺寸都能被2整除。
6. 训练技巧与优化
实现UNet结构只是第一步,要让���型真正发挥作用,还需要注意以下训练细节:
数据增强策略
- 随机旋转和翻转
- 弹性变形(特别适用于医学图像)
- 灰度值变化
# 示例数据增强变换
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(20),
transforms.ColorJitter(brightness=0.1, contrast=0.1),
transforms.ToTensor()
])
损失函数选择
- 二分类任务:BCEWithLogitsLoss
- 多分类任务:CrossEntropyLoss
- 类别不平衡:DiceLoss或FocalLoss
# Dice Loss实现示例
class DiceLoss(nn.Module):
def __init__(self, smooth=1.):
super(DiceLoss, self).__init__()
self.smooth = smooth
def forward(self, logits, targets):
probs = torch.sigmoid(logits)
intersection = (probs * targets).sum()
union = probs.sum() + targets.sum()
dice = (2. * intersection + self.smooth) / (union + self.smooth)
return 1 - dice
7. 实际应用中的调整
原始UNet设计针对的是医学图像,但在不同应用中可能需要调整:
输入输出尺寸
- 原始UNet:输入572×572,输出388×388
- 现代实现:通常使用填充保持输入输出尺寸一致
深度调整
- 小数据集:减少层数防止过拟合
- 复杂场景:增加层数或通道数
# 浅层UNet变体示例
class ShallowUNet(nn.Module):
def __init__(self, n_channels, n_classes):
super().__init__()
self.inc = DoubleConv(n_channels, 32)
self.down1 = Down(32, 64)
self.down2 = Down(64, 128)
self.up1 = Up(128, 64)
self.up2 = Up(64, 32)
self.outc = nn.Conv2d(32, n_classes, kernel_size=1)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x = self.up1(x3, x2)
x = self.up2(x, x1)
return self.outc(x)
理解UNet的关键在于实践。建议你在实现完整模型后,尝试以下实验:
- 移除所有跳跃连接,观察性能变化
- 尝试不同的上采样方法(双线性插值、最近邻等)
- 可视化中间特征图,直观理解各层学习的内容
更多推荐


所有评论(0)