从零构建UNet:用PyTorch拆解跳跃连接的秘密

当你第一次看到UNet的结构图时,那个对称的U型设计和那些横跨左右的连接线可能让你觉得既美观又神秘。但真正动手实现时,你是否遇到过这样的困惑:"这些跳跃连接到底在传输什么?为什么要在特定位置拼接特征图?"本文将带你用PyTorch从零开始构建一个完整的UNet,在每个关键节点打印张量形状,让你亲眼见证数据是如何在这个精巧的网络中流动的。

1. 理解UNet的核心设计

UNet最初是为医学图像分割设计的,它的成功很大程度上归功于其独特的架构。与传统的编码器-解码器结构不同,UNet引入了跳跃连接(Skip Connection),这使得网络能够同时利用低级的细节信息和高级的语义信息。

为什么跳跃连接如此重要?

  • 信息互补 :下采样路径捕获的是抽象的语义信息("这是什么"),而上采样路径需要精确的定位信息("在哪里")
  • 梯度流动 :跳跃连接为深层网络提供了更直接的梯度传播路径
  • 数据效率 :在医学影像等小数据集场景下,这种设计能更充分地利用有限的数据
import torch
import torch.nn as nn
from torch.nn import functional as F

# 打印设备信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

2. 构建基础模块

在搭建完整的UNet之前,我们需要先创建几个基础构建块。这些模块就像乐高积木一样,将被组合成最终的模型。

2.1 双卷积块

每个UNet的层级都包含两个连续的3×3卷积,这是特征提取的基本单元:

class DoubleConv(nn.Module):
    """(卷积 => [BN] => ReLU) * 2"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        print(f"DoubleConv input shape: {x.shape}")  # 调试输出
        return self.double_conv(x)

2.2 下采样模块

下采样(也称为收缩路径)通过最大池化逐步降低空间分辨率,同时增加通道数:

class Down(nn.Module):
    """下采样模块:最大池化后接双卷积"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
    
    def forward(self, x):
        print(f"Down input shape: {x.shape}")  # 调试输出
        return self.maxpool_conv(x)

2.3 上采样模块

上采样(扩展路径)通过转置卷积实现分辨率提升,并与跳跃连接的特征图拼接:

class Up(nn.Module):
    """上采样模块:转置卷积后与跳跃连接拼接,再接双卷积"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
    
    def forward(self, x1, x2):
        print(f"Up input shapes - x1: {x1.shape}, x2: {x2.shape}")  # 调试输出
        
        x1 = self.up(x1)
        # 计算填充以确保尺寸匹配
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        
        x = torch.cat([x2, x1], dim=1)
        print(f"After concatenation shape: {x.shape}")  # 调试输出
        return self.conv(x)

3. 组装完整的UNet

现在我们可以将这些模块组合成完整的UNet结构。我们将按照经典的UNet架构,包含4个下采样和4个上采样阶段。

class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        
        # 下采样路径
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        
        # 上采样路径
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        
        # 输出层
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
    
    def forward(self, x):
        print(f"Initial input shape: {x.shape}\n")
        
        # 下采样
        x1 = self.inc(x)
        print(f"After inc shape: {x1.shape}\n")
        
        x2 = self.down1(x1)
        print(f"After down1 shape: {x2.shape}\n")
        
        x3 = self.down2(x2)
        print(f"After down2 shape: {x3.shape}\n")
        
        x4 = self.down3(x3)
        print(f"After down3 shape: {x4.shape}\n")
        
        x5 = self.down4(x4)
        print(f"After down4 shape: {x5.shape}\n")
        
        # 上采样
        x = self.up1(x5, x4)
        print(f"After up1 shape: {x.shape}\n")
        
        x = self.up2(x, x3)
        print(f"After up2 shape: {x.shape}\n")
        
        x = self.up3(x, x2)
        print(f"After up3 shape: {x.shape}\n")
        
        x = self.up4(x, x1)
        print(f"After up4 shape: {x.shape}\n")
        
        # 输出
        logits = self.outc(x)
        print(f"Final output shape: {logits.shape}")
        return logits

4. 数据流动可视化

让我们创建一个测试输入,观察数据在网络中的流动过程:

# 创建测试输入
test_input = torch.randn((1, 3, 572, 572)).to(device)
print("测试输入张量形状:", test_input.shape)

# 初始化模型
model = UNet().to(device)

# 前向传播
with torch.no_grad():
    output = model(test_input)

运行这段代码,你将在控制台看到类似以下的输出:

测试输入张量形状: torch.Size([1, 3, 572, 572])
Initial input shape: torch.Size([1, 3, 572, 572])

DoubleConv input shape: torch.Size([1, 3, 572, 572])
After inc shape: torch.Size([1, 64, 572, 572])

Down input shape: torch.Size([1, 64, 572, 572])
DoubleConv input shape: torch.Size([1, 64, 286, 286])
After down1 shape: torch.Size([1, 128, 286, 286])

[...更多中间形状输出...]

Final output shape: torch.Size([1, 1, 388, 388])

5. 跳跃连接的内部机制

跳跃连接是UNet最精妙的设计,但也是最容易让人困惑的部分。让我们深入分析一个具体的上采样步骤:

# 假设我们正在处理up1阶段
x5_shape = torch.Size([1, 1024, 35, 35])  # 来自最深层的特征
x4_shape = torch.Size([1, 512, 35, 35])   # 对应的跳跃连接特征

# 上采样过程
up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
x5_up = up(x5)  # 形状变为 [1, 512, 70, 70]

# 裁剪x4以匹配尺寸
x4_cropped = center_crop(x4, [70, 70])

# 拼接操作
x = torch.cat([x4_cropped, x5_up], dim=1)  # 形状变为 [1, 1024, 70, 70]

为什么需要裁剪? 由于卷积的舍入误差,特征图尺寸可能会有微小差异。UNet原始论文中使用了镜像填充的卷积,确保所有中间特征图尺寸都能被2整除。

6. 训练技巧与优化

实现UNet结构只是第一步,要让���型真正发挥作用,还需要注意以下训练细节:

数据增强策略

  • 随机旋转和翻转
  • 弹性变形(特别适用于医学图像)
  • 灰度值变化
# 示例数据增强变换
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor()
])

损失函数选择

  • 二分类任务:BCEWithLogitsLoss
  • 多分类任务:CrossEntropyLoss
  • 类别不平衡:DiceLoss或FocalLoss
# Dice Loss实现示例
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    
    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        union = probs.sum() + targets.sum()
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice

7. 实际应用中的调整

原始UNet设计针对的是医学图像,但在不同应用中可能需要调整:

输入输出尺寸

  • 原始UNet:输入572×572,输出388×388
  • 现代实现:通常使用填充保持输入输出尺寸一致

深度调整

  • 小数据集:减少层数防止过拟合
  • 复杂场景:增加层数或通道数
# 浅层UNet变体示例
class ShallowUNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.inc = DoubleConv(n_channels, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.up1 = Up(128, 64)
        self.up2 = Up(64, 32)
        self.outc = nn.Conv2d(32, n_classes, kernel_size=1)
    
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        return self.outc(x)

理解UNet的关键在于实践。建议你在实现完整模型后,尝试以下实验:

  1. 移除所有跳跃连接,观察性能变化
  2. 尝试不同的上采样方法(双线性插值、最近邻等)
  3. 可视化中间特征图,直观理解各层学习的内容
Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐