AAAI 2025 | 对比驱动的通用医学图像分割框架,即插即用,高效涨点!
ConDSeg是一个通用医学图像分割框架,通过对比驱动特征增强解决边界模糊和目标共现问题。核心创新包括:1)一致性强化训练策略增强编码器鲁棒性;2)语义信息解耦模块将特征分为前景/背景/不确定区域;3)对比驱动特征聚合模块利用对比信息指导特征融合;4)尺寸感知解码器处理不同大小目标。该框架在结肠息肉、腺体等多种医学图像分割任务中表现优异,显著提升边界精度和抗干扰能力。关键模块可即插即用,代码已开源
1. 基本信息
-
标题: ConDSeg: A General Medical Image Segmentation Framework via Contrast-Driven Feature Enhancement (ConDSeg: 一个通过对比驱动特征增强的通用医学图像分割框架)
-
论文来源: https://arxiv.org/pdf/2412.08345
- 作者与单位:
-
Mengqi Lei, Haochen Wu, Xinhua Lv: 中国地质大学(武汉)
-
Xin Wang: 百度公司(Baidu Inc)
-
2. 核心创新点
-
一致性强化 (Consistency Reinforcement, CR) 训练策略:设计了一种预训练策略,通过对原始图像和强增广图像的输出进行一致性约束,显著增强编码器在弱光照、低对比度等恶劣环境下的特征提取鲁棒性。
-
语义信息解耦 (Semantic Information Decoupling, SID) 模块:提出将编码器的高层特征解耦为前景、背景和不确定性区域三个部分,并通过特定损失函数在训练中学习减少不确定性,从而更精确地区分前景与背景。
-
对比驱动特征聚合 (Contrast-Driven Feature Aggregation, CDFA) 模块:利用 SID 模块解耦出的前景和背景对比特征,来指导多层次特征的融合与关键特征的增强,有效应对目标与复杂背景的区分难题。
-
尺寸感知解码器 (Size-Aware Decoder, SA-Decoder):针对医学图像中普遍存在的共现现象,设计了能够分别预测不同尺寸实体的多个并行解码器,避免模型学习到错误的上下文关联,提高对不同尺寸目标的定位准确性。
➔➔➔➔点击查看原文,获取本文及其他精选即插即用模块集合https://mp.weixin.qq.com/s/1j1hEtFGwIeCrjt-ta0rYg
3. 方法详解
整体结构概述: ConDSeg 是一个两阶段的分割框架。第一阶段,采用一致性强化 (CR) 策略单独预训练编码器,以增强其在多变环境下的鲁棒性。第二阶段,以较低学习率微调该编码器,并训练完整的分割网络。该网络主要包含四个步骤:首先,ResNet-50 编码器提取多层次特征图;其次,语义信息解耦 (SID) 模块将最高层特征分解为前景、背景和不确定性特征;接着,对比驱动特征聚合 (CDFA) 模块利用这些对比信息来指导各层级特征的融合与增强;最后,多个并行的尺寸感知解码器 (SA-Decoder) 分别处理不同层级的输出,对不同尺寸的实体进行预测,并将结果融合得到最终的分割掩码。
ConDSeg 的整体框架图
步骤分解:
-
一致性强化 (CR):
-
此为第一阶段的预训练策略。将原始图像
X
和经过强数据增广(如亮度、对比度、饱和度随机变换)的图像X'
分别输入编码器和一 个简单的辅助预测头,得到两个预测掩码M1
和M2
。 -
通过最小化两个掩码与真值的分割损失 (
L_mask
),以及它们之间的**一致性损失 (L_cons
)**来训练编码器。这迫使编码器在不同光照和对比度条件下也能提取到稳定且高质量的特征。 -
一致性损失
L_cons
通过交叉计算两个预测掩码的 BCE 损失来实现,避免了传统方法(如 KL 散度)可能出现的数值不稳定问题。
-
-
语义信息解耦 (SID):
-
该模块接收编码器最深层的特征图
f4
,并通过三个并行的分支将其解耦为前景特征f_fg
、背景特征f_bg
和不确定性区域特征f_uc
。 -
通过一个辅助头,这三个特征分别生成对应的预测掩码
M_fg
,M_bg
,M_uc
。 -
设计了**互补性损失 (
L_compl
)**,确保每个像素点只属于三者之一,从而在训练中逐步减少不确定区域。 -
同时,为前景和背景损失引入了**动态惩罚项
β1, β2
**,根据预测区域面积的大小动态调整损失权重,以关注小目标的分割。
-
-
对比驱动特征聚合 (CDFA):
-
此模块旨在利用
f_fg
和f_bg
提供的对比信息来指导多级特征的融合。 -
[在此处插入图 4:CDFA 结构图]
-
在每个空间位置,CDFA 通过
f_fg
和f_bg
生成两组不同的注意力权重 (A_fg
,A_bg
),并对局部窗口内的特征值向量V
进行两次加权,从而实现前景和背景对比信息的注入,增强关键特征。 -
加权过程如下,其中
⊗
表示矩阵乘法:
-
-
尺寸感知解码器 (SA-Decoder):
-
为解决共现现象导致的模型误判,论文设计了三个独立的解码器,分别用于预测小、中、大尺寸的目标。
-
Decoder_s
、Decoder_m
和Decoder_l
分别接收来自不同层级 CDFA 模块的输出特征,因为浅层特征适合小目标,深层特征适合大目标。 -
最后,将三个解码器的输出沿通道维度拼接融合,生成最终的分割结果。这种设计使得模型能够有效地区分图像中的不同实体,而不是依赖于它们共同出现的上下文模式。
-
4. 即插即用模块作用
说明:ConDSeg 是一个完整的框架,但其核心模块 CR 训练策略、SID 模块、CDFA 模块 和 SA-Decoder 具有很强的通用性,可以被迁移到其他分割网络中。
适用场景
-
核心任务: 医学图像分割。
- 具体场景:
-
内窥镜图像分割 (如:结肠息肉分割)。
-
组织病理学图像分割 (如:结直肠腺体分割)。
-
皮肤镜图像分割 (如:皮肤恶性病变分割)。
-
其他具有模糊边界或共现现象的分割任务(如:3D 多类别分割,论文在补充材料中验证了其在 Synapse 数据集上的有效性)。
-
主要作用
- 模拟/替代能力:
-
CR 策略: 模拟了在各种恶劣成像条件下(弱光、低对比度)的稳健特征提取能力。
-
SID 模块: 替代了传统的边界预测辅助任务,通过显式建模前景、背景和不确定性,让模型自发学习减少模糊区域。
-
- 性能提升:
-
大幅增强模型鲁棒性: 对光照、对比度变化不敏感,确保在不同质量的图像上表现稳定。
-
显著提升边缘分割精度: 通过解耦和对比,有效处理前景与背景之间的“软边界”问题。
-
有效克服共现干扰: 通过尺寸感知解码,避免模型因学习到错误的物体共现规律而产生误判(例如,在只有单个息肉时错误预测多个)。
-
提升收敛速度与性能上限: 两阶段训练策略和高效的模块设计共同作用,使得模型收敛更快,性能更优 (如 图 5 所示)。
-
总结
ConDSeg 是一个通过“解耦-对比-聚合”范式,系统性解决医学图像分割中“边界模糊”和“目标共现”两大核心挑战的通用框架。
➔➔➔➔点击查看原文,获取本文及其他精选即插即用模块集合https://mp.weixin.qq.com/s/1j1hEtFGwIeCrjt-ta0rYg
5. 即插即用模块
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
class CBR(nn.Module):
def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, stride=1, act=True):
super().__init__()
self.act = act
self.conv = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size, padding=padding, dilation=dilation, bias=False, stride=stride),
nn.BatchNorm2d(out_c)
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
if self.act == True:
x = self.relu(x)
return x
class ContrastDrivenFeatureAggregation(nn.Module):
def __init__(self, in_c, dim, num_heads, kernel_size=3, padding=1, stride=1,
attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.v = nn.Linear(dim, dim)
self.attn_fg = nn.Linear(dim, kernel_size ** 4 * num_heads)
self.attn_bg = nn.Linear(dim, kernel_size ** 4 * num_heads)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
self.input_cbr = nn.Sequential(
CBR(in_c, dim, kernel_size=3, padding=1),
CBR(dim, dim, kernel_size=3, padding=1),
)
self.output_cbr = nn.Sequential(
CBR(dim, dim, kernel_size=3, padding=1),
CBR(dim, dim, kernel_size=3, padding=1),
)
def forward(self, x, fg, bg):
x = self.input_cbr(x)
x = x.permute(0, 2, 3, 1)
fg = fg.permute(0, 2, 3, 1)
bg = bg.permute(0, 2, 3, 1)
B, H, W, C = x.shape
v = self.v(x).permute(0, 3, 1, 2)
v_unfolded = self.unfold(v).reshape(B, self.num_heads, self.head_dim,
self.kernel_size * self.kernel_size,
-1).permute(0, 1, 4, 3, 2)
attn_fg = self.compute_attention(fg, B, H, W, C, 'fg')
x_weighted_fg = self.apply_attention(attn_fg, v_unfolded, B, H, W, C)
v_unfolded_bg = self.unfold(x_weighted_fg.permute(0, 3, 1, 2)).reshape(B, self.num_heads, self.head_dim,
self.kernel_size * self.kernel_size,
-1).permute(0, 1, 4, 3, 2)
attn_bg = self.compute_attention(bg, B, H, W, C, 'bg')
x_weighted_bg = self.apply_attention(attn_bg, v_unfolded_bg, B, H, W, C)
x_weighted_bg = x_weighted_bg.permute(0, 3, 1, 2)
out = self.output_cbr(x_weighted_bg)
return out
def compute_attention(self, feature_map, B, H, W, C, feature_type):
attn_layer = self.attn_fg if feature_type == 'fg' else self.attn_bg
h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
feature_map_pooled = self.pool(feature_map.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
attn = attn_layer(feature_map_pooled).reshape(B, h * w, self.num_heads,
self.kernel_size * self.kernel_size,
self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)
attn = attn * self.scale
attn = F.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
return attn
def apply_attention(self, attn, v, B, H, W, C):
x_weighted = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
B, self.dim * self.kernel_size * self.kernel_size, -1)
x_weighted = F.fold(x_weighted, output_size=(H, W), kernel_size=self.kernel_size,
padding=self.padding, stride=self.stride)
x_weighted = self.proj(x_weighted.permute(0, 2, 3, 1))
x_weighted = self.proj_drop(x_weighted)
return x_weighted
if __name__ == '__main__':
cdfa =ContrastDrivenFeatureAggregation(in_c=128, dim=128, num_heads=4)
# 输入特征图
x = torch.randn(1,128,32,32)
# 前景特征图
fg = torch.randn(1,128,32,32)
# 背景特征图
bg = torch.randn(1,128,32,32)
# 打印网络结构
print(cdfa)
#前向传播,输入张量x,fg,和bg
output = cdfa(x,fg,bg)
#打印输出张量的形状
print("input shape:", x.shape)
print("output shape:", output.shape)
更多推荐
所有评论(0)