Introduction

A while ago I read the UNet3+ paper on medical image segmentation and initially planned to just grab a Keras/TensorFlow implementation from GitHub, but everything I found seemed to deviate from the official code in some way, so I wrote my own version following the paper and the official source. I can't guarantee it matches the official code exactly either; I'm posting it to invite better implementations. Many blog posts explain UNet3+ well; for a full walkthrough see this translation, 【UNet3+(UNet+++)论文解读 玖零猴】. Here I'll also briefly share my own understanding.

UNet3+ paper
Official source code (PyTorch)


I. UNet3+

In short, UNet3+ makes four main contributions:
1 Full-scale skip connections, to prevent semantic information from being lost between downsampling and upsampling
2 Full-scale deep supervision, to learn hierarchical feature representations
3 A classification-guided module (CGM), to suppress false-positive segmentations caused by noise in medical images
4 A new hybrid loss function (TODO)

Er, the first three each have their caveats; more on that later.


The UNet3+ architecture is shown in the figure above. It is quite straightforward overall: the authors argue that neither UNet nor UNet++ connects feature maps across all scales, so they feed encoder features from every scale into each decoder and also pass decoder features across layers, reducing information loss along the way (simple and brute-force =_=).
[Figure: full-scale feature aggregation for decoder 3]
Take decoder 3 as an example: it fuses features from encoders 1, 2 and 3 and from decoders 4 and 5. These features are resized to decoder 3's spatial resolution by max pooling (for encoder features) or upsampling (for decoder features), and a convolution layer (conv + BN + ReLU in the source code) brings their channel counts to a common value. The concatenated feature maps then pass through one more conv + BN + ReLU block, and that's the decoder output.
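To make the scale bookkeeping concrete, here is a small pure-Python sketch (the function name and level numbering are my own, not from the paper or source code) that decides, for a source feature map at level s and a target decoder at level d, whether to max-pool or upsample and by what factor:

```python
def resize_op(source_level, target_level):
    """Return the operation and scale factor needed to bring a feature map
    from `source_level` to the spatial size of `target_level`.
    Levels run from 1 (full resolution) to 5 (bottleneck); each level
    halves the height and width."""
    if source_level < target_level:
        # From a shallower encoder layer: downsample with max pooling.
        return 'maxpool', 2 ** (target_level - source_level)
    elif source_level > target_level:
        # From a deeper encoder/decoder layer: bilinear upsampling.
        return 'upsample', 2 ** (source_level - target_level)
    return 'identity', 1

# Decoder 3 (De3) fuses En1, En2, En3 and De4, De5:
for s in [1, 2, 3, 4, 5]:
    print(f'level {s} -> De3:', resize_op(s, 3))
```

This matches the wiring in the model code below, e.g. En1 reaches De3 via an 8x... no, a 4x max pool, and En5 via a 4x upsample.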
[Figure: full-scale deep supervision and the classification-guided module (CGM)]
This figure illustrates the other two features: full-scale deep supervision and the classification-guided module (CGM).
Full-scale deep supervision computes a loss on the output of every decoder level.
To suppress false-positive segmentations caused by noise, the authors propose the classification-guided module. The CGM sits on the network bottleneck (the deepest encoder level, En5), where the network is deepest and the feature maps are most numerous and smallest, so some noise has presumably been filtered out. A small classification head (Dropout + Conv1x1 + Pooling + Sigmoid) attached to this level outputs a probability that the target organ is present in the input image; multiplying this classification result into the segmentation output suppresses false positives.
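Concretely, full-scale deep supervision just means the training loss becomes a (possibly weighted) sum of a per-pixel loss over all five side outputs. A minimal NumPy sketch using plain binary cross-entropy (the uniform weights are illustrative; the paper actually proposes a hybrid loss):

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over all pixels."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

def deep_supervision_loss(y_true, side_outputs, weights=None):
    """Sum the segmentation loss over every decoder side output.
    `side_outputs` are the full-resolution predictions from De1..De5."""
    if weights is None:
        weights = [1.0] * len(side_outputs)
    return sum(w * bce(y_true, y_hat) for w, y_hat in zip(weights, side_outputs))
```

With Keras this corresponds to passing one loss per named output when compiling the multi-output model built below.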

For the results, just look at the figure:
[Figure: metric comparison of UNet3+ against UNet and UNet++ from the paper]

That covers the features; now for the gripes:
1 Full-scale connections work well, and the authors specifically note that UNet3+ has fewer parameters than UNet and UNet++, but in practice it seems to need more training time and memory, apparently because UNet3+ performs more convolutions (for example, each UNet decoder level needs only 2 convolutions, while each UNet3+ decoder level in Fig. 2 above needs 6)
2 Still thinking about this one
3 The CGM is a very simple module; in my own experiments it overfit quickly even with Dropout: the segmentation head's validation loss was still decreasing while the CGM loss had already turned around and started rising.
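The convolution-count claim in point 1 can be read straight off Fig. 2: each UNet3+ decoder level runs one conv block per incoming branch plus one fusion conv after concatenation. A trivial sketch of that arithmetic:

```python
def unet3p_decoder_convs(num_branches=5):
    # One conv(+BN+ReLU) per incoming scale, plus one fusion conv after concat.
    return num_branches + 1

def unet_decoder_convs():
    # A plain UNet decoder level stacks just two convolutions.
    return 2

print(unet3p_decoder_convs(), 'vs', unet_decoder_convs())  # 6 vs 2
```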

II. Complete code (Keras)

Note: this code was written for fun and is not guaranteed to be correct; if you find problems, corrections and discussion are welcome. Please credit the source when reposting.
The implementation of the CGM output is debatable: in my code the CGM probability and the segmentation mask are output separately, so they have to be multiplied together manually afterwards.
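Since the model returns the CGM probability and the segmentation masks as separate outputs, the gating step has to happen after prediction. A hedged NumPy sketch of that post-processing (the function name and the assumed shapes, masks of (batch, H, W, class_num) and a CGM output of (batch, class_num), are mine):

```python
import numpy as np

def apply_cgm(masks, cgm_probs, threshold=0.5):
    """Zero out segmentation masks for images the CGM classifies as
    containing no target organ (probability below `threshold`)."""
    # Hard 0/1 gate per image and class, broadcast over the spatial dims.
    gate = (cgm_probs > threshold).astype(masks.dtype)   # (batch, class_num)
    return masks * gate[:, None, None, :]                # (batch, H, W, class_num)
```

The paper multiplies the classification result into the segmentation outputs inside the network; doing it outside like this is a simplification.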

1. Imports

import tensorflow as tf
import numpy as np
from keras import backend as K
from keras.models import Model
from keras.layers import (Conv2D, Input, concatenate, add, MaxPooling2D,
                          AveragePooling2D, UpSampling2D, Activation,
                          BatchNormalization, LayerNormalization, Dropout,
                          GlobalMaxPooling2D)

2. Helper functions

# Helper functions to build UNet3+.
def normalization(input_tensor, normalization):
    '''Apply the chosen normalization: 'batch', 'layer', or None.'''
    if normalization == 'batch':
        return BatchNormalization()(input_tensor)
    elif normalization == 'layer':
        return LayerNormalization()(input_tensor)
    elif normalization is None:
        return input_tensor
    else:
        raise ValueError('Invalid normalization')

def conv2d_block(input_tensor, filters, kernel_size,
                 norm_type, use_residual, act_type='relu',
                 double_features=False, dilation=[1, 1]):
    '''Two conv + normalization + activation layers, optionally with a residual connection.'''
    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[0],
               use_bias=False, kernel_initializer='he_normal')(input_tensor)
    x = normalization(x, norm_type)
    x = Activation(act_type)(x)

    # Optionally double the feature maps in the second convolution
    # (suggested in the 3D U-Net paper to avoid a bottleneck).
    if double_features:
        filters *= 2

    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[1],
               use_bias=False, kernel_initializer='he_normal')(x)
    x = normalization(x, norm_type)

    if use_residual:
        if K.int_shape(input_tensor)[-1] != K.int_shape(x)[-1]:
            # Project the shortcut with a 1x1 convolution when the channel
            # counts differ, so the two tensors can be added.
            shortcut = Conv2D(filters, kernel_size=1, padding='same',
                              use_bias=False, kernel_initializer='he_normal')(input_tensor)
            shortcut = normalization(shortcut, norm_type)
            x = add([x, shortcut])
        else:
            x = add([x, input_tensor])

    x = Activation(act_type)(x)

    return x

def down_layer_2d(input_tensor, down_pattern, filters, norm_type=None):
    '''Halve the spatial resolution with the chosen downsampling method.'''
    if down_pattern == 'maxpooling':
        x = MaxPooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'avgpooling':
        x = AveragePooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'conv':
        # Strided convolution; the bias is redundant when normalization follows.
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same',
                   use_bias=norm_type is None, kernel_initializer='he_normal')(input_tensor)
        x = normalization(x, norm_type)
    elif down_pattern == 'normconv':
        x = normalization(input_tensor, norm_type)
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same',
                   kernel_initializer='he_normal')(x)
    else:
        raise ValueError('Invalid down_pattern')
    return x

def conv_norm_act(input_tensor, filters, kernel_size, norm_type='batch', act_type='relu', dilation=1):
    output_tensor = Conv2D(filters, kernel_size, padding='same', dilation_rate=(dilation, dilation),
                           use_bias=norm_type is None, kernel_initializer='he_normal')(input_tensor)
    output_tensor = normalization(output_tensor, norm_type)
    output_tensor = Activation(act_type)(output_tensor)
    return output_tensor

def aggregate(l1, l2, l3, l4, l5, filters, kernel_size, norm_type='batch', act_type='relu'):
    '''Concatenate the five rescaled feature maps and fuse them with one conv block.'''
    out = concatenate([l1, l2, l3, l4, l5], axis=-1)
    out = Conv2D(filters * 5, kernel_size, padding='same',
                 use_bias=norm_type is None, kernel_initializer='he_normal')(out)
    out = normalization(out, norm_type)
    out = Activation(act_type)(out)

    return out

def cgm_block(input_tensor, class_num, dropout_rate=0.):
    '''Classification-guided module: predicts whether each class is present in the image.'''
    x = Dropout(rate=dropout_rate)(input_tensor)
    x = Conv2D(class_num, 1, padding='same', kernel_initializer='he_normal')(x)
    # Global max pooling stands in for the adaptive max pooling used in the
    # paper; for a 1x1 output the two should be equivalent.
    x = GlobalMaxPooling2D()(x)
    x = Activation('sigmoid', name='cgm_output')(x)

    return x

3. Building the network

# Build the UNet3+ model.
def unet3p_2d(input_shape, initial_features=32, kernel_size=3,
              class_num=1, norm_type='batch', double_features=False,
              use_residual=False, down_pattern='maxpooling', using_deep_supervision=True, 
              using_cgm=False, cgm_drop_rate=0.5, show_summary=True):
    '''
    input_shape: (height, width, channels)
    initial_features: int, number of feature maps at the first level, doubled after
        each downsampling; the UNet3+ paper uses 64
    kernel_size: int, convolution kernel size
    class_num: int, number of segmentation classes
    norm_type: str, normalization type, 'batch' or 'layer'; UNet3+ uses BatchNormalization
    double_features: bool, whether the second convolution in each conv2d_block doubles
        the number of feature maps; the 3D U-Net paper suggests this to avoid a
        bottleneck, and it can usually be left False
    use_residual: bool, whether the encoder uses residual connections
    down_pattern: str, downsampling method: 'maxpooling', 'avgpooling', 'conv' or
        'normconv'; UNet3+ uses max pooling
    using_deep_supervision: bool, whether to use full-scale deep supervision
    using_cgm: bool, whether to use the classification-guided module (CGM)
    cgm_drop_rate: float, dropout rate inside the CGM
    show_summary: bool, whether to print the model summary
    '''

    if class_num == 1:
        last_layer_activation = 'sigmoid'
    else:
        last_layer_activation = 'softmax'
    
    inputs = Input(input_shape)

    xe1 = conv2d_block(input_tensor=inputs, filters=initial_features, kernel_size=kernel_size, 
                    norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe1_pool = down_layer_2d(input_tensor=xe1, down_pattern=down_pattern, filters=initial_features)

    xe2 = conv2d_block(input_tensor=xe1_pool, filters=initial_features * 2, kernel_size=kernel_size, 
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe2_pool = down_layer_2d(input_tensor=xe2, down_pattern=down_pattern, filters=initial_features * 2)

    xe3 = conv2d_block(input_tensor=xe2_pool, filters=initial_features * 4, kernel_size=kernel_size, 
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe3_pool = down_layer_2d(input_tensor=xe3, down_pattern=down_pattern, filters=initial_features * 4)

    xe4 = conv2d_block(input_tensor=xe3_pool, filters=initial_features * 8, kernel_size=kernel_size, 
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe4_pool = down_layer_2d(input_tensor=xe4, down_pattern=down_pattern, filters=initial_features * 8)

    xe5 = conv2d_block(input_tensor=xe4_pool, filters=initial_features * 16, kernel_size=kernel_size, 
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)

    if using_cgm:
        cgm = cgm_block(input_tensor=xe5, class_num=class_num, dropout_rate=cgm_drop_rate)

    # De4: En5 upsampled by 2, En4 as-is, En3/En2/En1 max-pooled by 2/4/8.
    xd4_from_xe5 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xe5)
    xd4_from_xe5 = conv_norm_act(input_tensor=xd4_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe4 = conv_norm_act(input_tensor=xe4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe3 = MaxPooling2D(pool_size=(2, 2))(xe3)
    xd4_from_xe3 = conv_norm_act(input_tensor=xd4_from_xe3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe2 = MaxPooling2D(pool_size=(4, 4))(xe2)
    xd4_from_xe2 = conv_norm_act(input_tensor=xd4_from_xe2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4_from_xe1 = MaxPooling2D(pool_size=(8, 8))(xe1)
    xd4_from_xe1 = conv_norm_act(input_tensor=xd4_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd4 = aggregate(xd4_from_xe5, xd4_from_xe4, xd4_from_xe3, xd4_from_xe2, xd4_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    # De3: En5/De4 upsampled by 4/2, En3 as-is, En2/En1 max-pooled by 2/4.
    xd3_from_xe5 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xe5)
    xd3_from_xe5 = conv_norm_act(input_tensor=xd3_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xd4 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd4)
    xd3_from_xd4 = conv_norm_act(input_tensor=xd3_from_xd4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xe3 = conv_norm_act(input_tensor=xe3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xe2 = MaxPooling2D(pool_size=(2, 2))(xe2)
    xd3_from_xe2 = conv_norm_act(input_tensor=xd3_from_xe2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3_from_xe1 = MaxPooling2D(pool_size=(4, 4))(xe1)
    xd3_from_xe1 = conv_norm_act(input_tensor=xd3_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd3 = aggregate(xd3_from_xe5, xd3_from_xd4, xd3_from_xe3, xd3_from_xe2, xd3_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    # De2: En5/De4/De3 upsampled by 8/4/2, En2 as-is, En1 max-pooled by 2.
    xd2_from_xe5 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xe5)
    xd2_from_xe5 = conv_norm_act(input_tensor=xd2_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xd4 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd4)
    xd2_from_xd4 = conv_norm_act(input_tensor=xd2_from_xd4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xd3 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd3)
    xd2_from_xd3 = conv_norm_act(input_tensor=xd2_from_xd3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xe2 = conv_norm_act(input_tensor=xe2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2_from_xe1 = MaxPooling2D(pool_size=(2, 2))(xe1)
    xd2_from_xe1 = conv_norm_act(input_tensor=xd2_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd2 = aggregate(xd2_from_xe5, xd2_from_xd4, xd2_from_xd3, xd2_from_xe2, xd2_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    # De1: En5/De4/De3/De2 upsampled by 16/8/4/2, En1 as-is.
    xd1_from_xe5 = UpSampling2D(size=(16, 16), interpolation='bilinear')(xe5)
    xd1_from_xe5 = conv_norm_act(input_tensor=xd1_from_xe5, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xd4 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xd4)
    xd1_from_xd4 = conv_norm_act(input_tensor=xd1_from_xd4, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xd3 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd3)
    xd1_from_xd3 = conv_norm_act(input_tensor=xd1_from_xd3, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xd2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd2)
    xd1_from_xd2 = conv_norm_act(input_tensor=xd1_from_xd2, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1_from_xe1 = conv_norm_act(input_tensor=xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)
    xd1 = aggregate(xd1_from_xe5, xd1_from_xd4, xd1_from_xd3, xd1_from_xd2, xd1_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    if using_deep_supervision:
        xd55 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xe5)
        xd55 = UpSampling2D(size=(16, 16))(xd55)
        xd55 = Activation(last_layer_activation, name='output_de5')(xd55)

        xd44 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd4)
        xd44 = UpSampling2D(size=(8, 8))(xd44)
        xd44 = Activation(last_layer_activation, name='output_de4')(xd44)

        xd33 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd3)
        xd33 = UpSampling2D(size=(4, 4))(xd33)
        xd33 = Activation(last_layer_activation, name='output_de3')(xd33)

        xd22 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd2)
        xd22 = UpSampling2D(size=(2, 2))(xd22)
        xd22 = Activation(last_layer_activation, name='output_de2')(xd22)

        xd11 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd1)
        xd11 = Activation(last_layer_activation, name='output_de1')(xd11)

        if using_cgm: outputs=[xd11, xd22, xd33, xd44, xd55, cgm]
        else: outputs=[xd11, xd22, xd33, xd44, xd55]

    else:
        conv_output = Conv2D(class_num, 1, activation=last_layer_activation, name='output')(xd1)
        if using_cgm: outputs=[conv_output, cgm]
        else: outputs = conv_output

    model = Model(inputs, outputs)
    if show_summary: model.summary()

    return model

4. Creating the model

If all of the code above is in the same .py file, you can append the following to try building the network:

if __name__ == '__main__':
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling', 
                      using_deep_supervision=True, using_cgm=False, show_summary=True)
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling', 
                      using_deep_supervision=True, using_cgm=True, show_summary=True)
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling', 
                      using_deep_supervision=False, using_cgm=False, show_summary=True)

If you use a pretrained backbone, the encoder (En) part needs to be modified accordingly.

I feel like such a noob... hope graduation goes smoothly, sigh TAT
