参考链接1(GitHub):https://github.com/qubvel/segmentation_models.pytorch —— qubvel/segmentation_models.pytorch: Segmentation models with pretrained backbones. PyTorch.

参考链接2:框架调用:https://github.com/WangZhenqing-RS/2021Tianchi_RS

链接1的框架中包含 Unet、Unet++、MAnet、Linknet、FPN、PSPNet、PAN、DeepLabV3、DeepLabV3+ 等模型,具体细节请到原链接中查看

这个语义分割框架包含了很多常见的模型。本文结合了链接2作者在遥感大赛中使用的训练策略,调用的是链接1框架中的模型。虽然链接2的代码已经可以直接使用,但我觉得实际用起来需要改的地方不少,所以这里放一个使用起来相对简单的版本,相对原版改动较大。使用之前最好先到原版仓库里仔细看看各部分的作用。我保留了大部分策略,效果不错:把数据放好、选好模型,就可以直接训练

实际效果如下:

原图

 标签:

 预测结果:

 精度评价:

acc:  0.9674116041666667
acc_cls:  0.9086168591905439
iou:  [0.9643351  0.72580183]
miou:  0.8450684631527481
fwavacc:  0.939678139218996
class_recall:  0.8344991600686457
f1_score:  0.8411183893451285

1.数据准备(文件名字最好一致,不一致就去代码里改)

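一级目录:数据集根目录,即代码中的 data_root(示例中为 ./dataset/build/)

二级目录:train、val、test 三个文件夹(训练代码读取 train 和 val,预测代码读取 test/images)

三级目录:每个二级目录下有 images 和 labels 两个文件夹,里面就是图像和标签文件了,一一对应,名字相同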

2.训练代码 

segmentation_models_pytorch就是链接1,可以直接通过pip安装(pip install segmentation-models-pytorch)

pytorch_toolbelt也需要安装(pip install pytorch_toolbelt)

# -*- coding: utf-8 -*-
import time
import warnings
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss, SoftCrossEntropyLoss, LovaszLoss
from pytorch_toolbelt import losses as L
from data import ImageFolder

# Suppress every Python warning for the whole process, including those raised by libraries.
warnings.filterwarnings('ignore')
# Allow PyTorch to use cuDNN kernels.
torch.backends.cudnn.enabled = True

# Device string used for the model and the batches: first CUDA GPU if available, otherwise CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' 

def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor, SMOOTH = 1e-6):
    """Return the batch-mean thresholded IoU score of binary predictions.

    ``outputs`` and ``labels`` must be integer or boolean tensors, because
    they are combined with bitwise AND / OR. A size-1 dimension at index 1
    of ``outputs`` is squeezed away before the comparison. ``SMOOTH`` is
    added to numerator and denominator so that an empty prediction paired
    with an empty label yields an IoU of 1 instead of 0/0.

    Each per-sample IoU is mapped onto a score in {0.0, 0.1, ..., 1.0}:
    0 for IoU <= 0.5, then rounded up in steps of 0.1 (IoU = 1 gives 1.0).
    The mean of these scores over the batch is returned as a 0-dim tensor.
    """
    preds = outputs.squeeze(1)
    spatial_dims = (1, 2)

    # Pixel counts of the overlap and of the union, one value per sample.
    overlap = torch.bitwise_and(preds, labels).float().sum(spatial_dims)
    combined = torch.bitwise_or(preds, labels).float().sum(spatial_dims)

    ratio = (overlap + SMOOTH) / (combined + SMOOTH)

    # Same as counting how many of the thresholds 0.5, 0.55, ..., 0.95
    # the IoU exceeds and dividing by 10.
    score = torch.clamp(20 * (ratio - 0.5), 0, 10).ceil() / 10

    return score.mean()

def train(EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
          model_path, swa_model_path, loss, early_stop):
    """Train a binary-segmentation U-Net and keep the best checkpoint.

    Args:
        EPOCHES: Maximum number of epochs.
        BATCH_SIZE: Batch size used by both data loaders.
        data_root: Dataset root handed to ``ImageFolder`` (modes ``train``
            and ``val``).
        channels: Number of input channels of the network.
        optimizer_name: ``"sgd"`` selects SGD; any other value selects AdamW.
        model_path: File the best ``state_dict`` is saved to. Its parent
            directory is created when missing.
        swa_model_path: Not used by this function; kept so that existing
            callers keep working.
        loss: ``"SoftCE_dice"`` selects Dice + BCE; any other value selects
            Lovasz + SoftCrossEntropy.
        early_stop: Training stops once this many epochs have passed since
            the best validation score.

    Returns:
        Tuple ``(train_loss_epochs, val_mIoU_epochs, lr_epochs)`` with one
        entry per completed epoch.
    """
    import os  # local import: only needed to create the checkpoint directory

    train_dataset = ImageFolder(data_root, mode='train')
    val_dataset = ImageFolder(data_root, mode='val')

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0)

    # No shuffling: fixed batches keep the validation score comparable
    # between epochs.
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0)

    # Binary segmentation: a single output channel followed by a sigmoid.
    # Other smp architectures / encoders (e.g. smp.UnetPlusPlus with
    # "efficientnet-b7" or "timm-resnest101e") can be swapped in here.
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=channels,
        classes=1,
        activation='sigmoid',
    )
    model.to(DEVICE)

    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4, weight_decay=1e-3, momentum=0.9)
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=1e-3, weight_decay=1e-3)

    # Cosine annealing with warm restarts: the first cycle lasts T_0 epochs
    # and every following cycle is T_mult times longer.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,
        T_mult=2,
        eta_min=1e-5,  # lowest learning rate
    )

    if loss == "SoftCE_dice":
        # The model already ends in a sigmoid, so the Dice loss must not
        # apply its own sigmoid again (its default is from_logits=True).
        DiceLoss_fn = DiceLoss(mode='binary', from_logits=False)
        Bceloss_fn = nn.BCELoss()
        # .to(DEVICE) instead of .cuda() so that the CPU fallback works.
        loss_fn = L.JointLoss(first=DiceLoss_fn, second=Bceloss_fn,
                              first_weight=0.5, second_weight=0.5).to(DEVICE)
    else:
        # NOTE(review): SoftCrossEntropyLoss is a multi-class loss and was
        # not adapted to the single-channel sigmoid output used above;
        # adjust this branch before relying on it.
        LovaszLoss_fn = LovaszLoss(mode='binary')
        SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
        loss_fn = L.JointLoss(first=LovaszLoss_fn, second=SoftCrossEntropy_fn,
                              first_weight=0.5, second_weight=0.5).to(DEVICE)

    # torch.save does not create missing directories.
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    best_miou = 0
    best_miou_epoch = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []
    for epoch in range(1, EPOCHES+1):
        losses = []
        start_time = time.time()

        model.train()
        for image, target in tqdm(train_data_loader, ncols=20, total=len(train_data_loader)):
            image, target = image.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(image)
            batch_loss = loss_fn(output, target)
            batch_loss.backward()
            optimizer.step()
            losses.append(batch_loss.item())

        scheduler.step()

        # Validate in eval mode (fixed BatchNorm statistics) and without
        # building autograd graphs.
        model.eval()
        val_iou = []
        with torch.no_grad():
            for val_img, val_mask in tqdm(val_data_loader, ncols=20, total=len(val_data_loader)):
                val_img, label = val_img.to(DEVICE), val_mask.to(DEVICE)
                predict = (model(val_img) >= 0.5).long()
                label = label.squeeze(1).long()
                val_iou.append(iou_pytorch(predict, label).item())

        epoch_loss = float(np.mean(losses)) if losses else float('nan')
        epoch_miou = float(np.mean(val_iou)) if val_iou else 0.0

        train_loss_epochs.append(epoch_loss)
        val_mIoU_epochs.append(epoch_miou)
        lr_epochs.append(optimizer.param_groups[0]['lr'])

        print('Epoch:' + str(epoch) + ' Loss:' + str(epoch_loss) + ' Val_IOU:' + str(epoch_miou) + ' Time_use:' + str((time.time()-start_time)/60.0))

        if best_miou < epoch_miou:
            best_miou = epoch_miou
            best_miou_epoch = epoch
            torch.save(model.state_dict(), model_path)
            print("  valid mIoU is improved. the model is saved.")
        else:
            print("")
            if (epoch - best_miou_epoch) >= early_stop:
                break

    return train_loss_epochs, val_mIoU_epochs, lr_epochs


 
if __name__ == '__main__':
    # Training configuration.
    EPOCHES = 500
    BATCH_SIZE = 8
    channels = 3
    optimizer_name = "adamw"
    early_stop = 400
    loss = "SoftCE_dice"  # alternative: "SoftCE_Lovasz"

    # Dataset location and checkpoint file names (derived from the loss name).
    data_root = "./dataset/build/"
    weights_prefix = "./weights/" + "_" + loss
    swa_model_path = weights_prefix + "_swa.pth"
    model_path = weights_prefix + ".pth"

    train_loss_epochs, val_mIoU_epochs, lr_epochs = train(
        EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
        model_path, swa_model_path, loss, early_stop)

    # Plot the training curves; matplotlib is only needed here.
    import matplotlib.pyplot as plt
    epochs = range(1, len(train_loss_epochs) + 1)

    # Figure 1: training loss and validation mIoU per epoch.
    plt.plot(epochs, train_loss_epochs, 'r', label='train loss')
    plt.plot(epochs, val_mIoU_epochs, 'b', label='val mIoU')
    plt.title('train loss and val mIoU')
    plt.legend()
    plt.savefig("train loss and val mIoU.png", dpi=300)

    # Figure 2: learning rate per epoch.
    plt.figure()
    plt.plot(epochs, lr_epochs, 'r', label='learning rate')
    plt.title('learning rate')
    plt.legend()
    plt.savefig("learning rate.png", dpi=300)

    plt.show()

下面是链接1中激活函数的位置,可以选择的很多,二分类用的是sigmoid

https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/base/modules.py

不同的模型需要填写的参数可能不同,可能需要去链接1里的模型位置好好看下,具体链接:https://github.com/qubvel/segmentation_models.pytorch/tree/master/segmentation_models_pytorch

以deeplabV3为例,模型结构的位置如下

https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/deeplabv3/model.py

可调整的参数就在初始化的地方,找到位置好好看下就行了

3.数据加载代码data.py

注意:own_data_loader 和 own_data_test_loader 在加载数据时会把图像和标签 resize 到 512×512,不需要的话可以把对应的 resize 代码注释掉

# -*- coding: utf-8 -*-
import torch
import torch.utils.data as data
from torch.autograd import Variable as V
from PIL import Image

import cv2
import numpy as np
import os
import scipy.misc as misc

def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    """Randomly jitter hue, saturation and value of a BGR image.

    With probability ``u`` the image is converted to HSV, each channel is
    shifted by a random amount drawn from its ``*_shift_limit`` range
    (inclusive integer range for hue, uniform float range for saturation
    and value) and the result is converted back to BGR. Otherwise the
    input is returned unchanged.

    Assumes an 8-bit image such as the one returned by ``cv2.imread``
    (TODO confirm for other callers).
    """
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)

        hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1]+1)
        # 8-bit OpenCV hue is stored in [0, 179]. Do the shift in a signed
        # type and wrap modulo 180: casting a negative shift to uint8 fails
        # on NumPy 2 and wrapped modulo 256 on older versions.
        h = ((h.astype(np.int16) + hue_shift) % 180).astype(np.uint8)

        # cv2.add saturates at the limits of the channel type.
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)

        image = cv2.merge((h, s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image

def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0), 
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    """Apply one random shift / scale / aspect / rotation to image and mask.

    With probability ``u`` a single transform is sampled and applied to both
    ``image`` and ``mask`` so that they stay aligned; otherwise both are
    returned unchanged.

    Args:
        image, mask: Arrays whose first two dimensions are (height, width).
        shift_limit: Range of the translation as a fraction of width/height.
        scale_limit: Range added to 1 to obtain the scale factor.
        rotate_limit: Range of the rotation angle in degrees.
        aspect_limit: Range added to 1 to obtain the aspect-ratio factor.
        borderMode: OpenCV border mode used for pixels outside the source.
        u: Probability of applying the transform.

    Returns:
        Tuple ``(image, mask)``.
    """
    if np.random.random() < u:
        # Only the spatial size is needed; this also accepts 2-D images.
        height, width = image.shape[:2]

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        # np.math was removed in NumPy 2.0; use NumPy's own functions.
        cc = np.cos(angle / 180 * np.pi) * sx
        ss = np.sin(angle / 180 * np.pi) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Map the four image corners through the transform (about the image
        # centre) and let OpenCV derive the warp from the corner pairs.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR,
                                    borderMode=borderMode, borderValue=(0, 0, 0,))
        mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR,
                                   borderMode=borderMode, borderValue=(0, 0, 0,))

    return image, mask

def randomHorizontalFlip(image, mask, u=0.5):
    """Mirror image and mask left-right together with probability ``u``.

    Returns the pair unchanged when the random draw does not trigger.
    """
    flip = np.random.random() < u
    if not flip:
        return image, mask

    # flipCode 1 flips around the vertical axis.
    flipped_image = cv2.flip(image, 1)
    flipped_mask = cv2.flip(mask, 1)
    return flipped_image, flipped_mask

def randomVerticleFlip(image, mask, u=0.5):
    """Mirror image and mask top-bottom together with probability ``u``.

    Returns the pair unchanged when the random draw does not trigger.
    """
    flip = np.random.random() < u
    if not flip:
        return image, mask

    # flipCode 0 flips around the horizontal axis.
    flipped_image = cv2.flip(image, 0)
    flipped_mask = cv2.flip(mask, 0)
    return flipped_image, flipped_mask

def randomRotate90(image, mask, u=0.5):
    """Rotate image and mask together by 90 degrees with probability ``u``.

    The rotation is ``np.rot90`` with its default arguments. The pair is
    returned unchanged when the random draw does not trigger.
    """
    rotate = np.random.random() < u
    if rotate:
        return np.rot90(image), np.rot90(mask)
    return image, mask


def default_loader(img_path, mask_path):
    """Load one augmented 448x448 sample whose mask is inverted.

    The image is read in colour and the mask in grayscale. The mask is
    inverted (``255 - mask``) before augmentation, so dark label pixels
    become the positive class.

    Returns:
        ``(img, mask)`` as float32 arrays in channel-first layout: ``img``
        scaled from [0, 255] to [-1.6, 1.6], ``mask`` binarised to 0 / 1
        with a leading channel axis of size 1.

    Raises:
        OSError: if either file cannot be read by OpenCV.
    """
    # cv2.imread returns None instead of raising when a file is missing or
    # cannot be decoded; fail here with the offending path.
    img = cv2.imread(img_path)
    if img is None:
        raise OSError("Failed to read image: {}".format(img_path))
    img = cv2.resize(img, (448, 448))

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise OSError("Failed to read mask: {}".format(mask_path))
    mask = 255. - cv2.resize(mask, (448, 448))

    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))

    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)

    # Give the mask an explicit channel axis: (H, W) -> (H, W, 1).
    mask = np.expand_dims(mask, axis=2)

    img = np.array(img, np.float32).transpose(2,0,1)/255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32).transpose(2,0,1)/255.0
    # Binarise: the warp above interpolates, so intermediate values occur.
    mask[mask >= 0.5] = 1
    mask[mask < 0.5] = 0
    return img, mask

def read_own_data(root_path, mode = 'train'):
    """List the image and label paths of one dataset split.

    The file names are taken from ``<root_path>/<mode>/labels``; every label
    is paired with the file of the same name under
    ``<root_path>/<mode>/images``. The existence of the image file is not
    checked here.

    Returns:
        ``(images, masks)``: two lists of equal length, sorted by file name.
    """
    image_root = os.path.join(root_path, mode + '/images')
    gt_root = os.path.join(root_path, mode + '/labels')

    # os.listdir returns entries in arbitrary order; sort them so that the
    # index -> sample mapping is the same on every run and platform.
    names = sorted(os.listdir(gt_root))

    images = [os.path.join(image_root, name) for name in names]
    masks = [os.path.join(gt_root, name) for name in names]
    return images, masks

def own_data_loader(img_path, mask_path):
    """Load one 512x512 sample and apply the random augmentations.

    Image and mask are resized with nearest-neighbour interpolation, then
    jittered in HSV, shifted / scaled, flipped and rotated together.

    Returns:
        ``(img, mask)`` as float32 arrays: ``img`` with shape (3, 512, 512)
        scaled from [0, 255] to [-1.6, 1.6], ``mask`` with shape
        (1, 512, 512) holding 0 / 1.

    Raises:
        OSError: if either file cannot be read by OpenCV.
    """
    # cv2.imread returns None instead of raising when a file is missing or
    # cannot be decoded; fail here with the offending path.
    img = cv2.imread(img_path)
    if img is None:
        raise OSError("Failed to read image: {}".format(img_path))
    img = cv2.resize(img, (512,512), interpolation = cv2.INTER_NEAREST)

    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise OSError("Failed to read mask: {}".format(mask_path))
    mask = cv2.resize(mask, (512,512), interpolation = cv2.INTER_NEAREST)

    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))

    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)

    # Give the mask an explicit channel axis: (H, W) -> (H, W, 1).
    mask = np.expand_dims(mask, axis=2)

    img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32) / 255.0
    # Binarise: the warp above interpolates, so intermediate values occur.
    mask[mask >= 0.5] = 1
    mask[mask < 0.5] = 0

    # Channel-last -> channel-first.
    return img.transpose(2, 0, 1), mask.transpose(2, 0, 1)

def own_data_test_loader(img_path, mask_path):
    """Load one 512x512 sample without any random augmentation.

    Returns:
        ``(img, mask)`` as float32 arrays: ``img`` with shape (3, 512, 512)
        scaled from [0, 255] to [-1.6, 1.6], ``mask`` with shape
        (1, 512, 512) holding 0 / 1.

    Raises:
        OSError: if either file cannot be read by OpenCV.
    """
    # cv2.imread returns None instead of raising when a file is missing or
    # cannot be decoded; fail here with the offending path.
    img = cv2.imread(img_path)
    if img is None:
        raise OSError("Failed to read image: {}".format(img_path))
    img = cv2.resize(img, (512,512), interpolation = cv2.INTER_NEAREST)

    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise OSError("Failed to read mask: {}".format(mask_path))
    mask = cv2.resize(mask, (512,512), interpolation = cv2.INTER_NEAREST)
    # Give the mask an explicit channel axis: (H, W) -> (H, W, 1).
    mask = np.expand_dims(mask, axis=2)

    img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32) / 255.0
    mask[mask >= 0.5] = 1
    mask[mask < 0.5] = 0

    # Channel-last -> channel-first.
    return img.transpose(2, 0, 1), mask.transpose(2, 0, 1)

class ImageFolder(data.Dataset):
    """Dataset of (image, mask) pairs read from ``<root>/<mode>/images|labels``.

    ``mode='train'`` applies the random augmentations of ``own_data_loader``;
    every other mode (e.g. ``'val'``, ``'test'``) loads the sample
    deterministically with ``own_data_test_loader`` so that evaluation
    scores are comparable between runs. Both items are always returned as
    ``torch.Tensor``.
    """

    def __init__(self,root_path, mode='train'):
        self.root = root_path
        self.mode = mode
        # Parallel lists: self.images[i] belongs to self.labels[i].
        self.images, self.labels = read_own_data(self.root, self.mode)

    def __getitem__(self, index):
        """Return the ``index``-th sample as ``(img, mask)`` tensors."""
        if self.mode == 'train':
            img, mask = own_data_loader(self.images[index], self.labels[index])
        else:
            img, mask = own_data_test_loader(self.images[index], self.labels[index])
        # Convert in every mode so that direct indexing and DataLoader
        # batches see the same type.
        img = torch.Tensor(img)
        mask = torch.Tensor(mask)
        return img, mask

    def __len__(self):
        """Return the number of samples."""
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)

4.预测代码inference.py

预测只用了test_1函数,保留了原作者的多模型预测函数test,感兴趣的朋友改改试试看

# -*- coding: utf-8 -*-
import os
import glob
import time
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
from torch.optim.swa_utils import AveragedModel
from data import ImageFolder

# Device string used for inference: first CUDA GPU if available, otherwise CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' 

def test(model_path_effi7, model_path_resnest, output_dir, test_loader):
    """Ensemble prediction with two 10-class UNet++ models and flip TTA.

    Kept from the original multi-model pipeline. For every batch the two
    models are evaluated on the image, on its stretched version and on two
    flipped versions; the eight outputs are averaged, the arg-max class
    (shifted to start at 1) is taken per pixel and written as PNG.

    Args:
        model_path_effi7: Checkpoint of the efficientnet-b7 model.
        model_path_resnest: Checkpoint of the timm-resnest101e model.
        output_dir: Directory the PNG predictions are written to (created
            when missing).
        test_loader: Iterable yielding
            ``(image, image_stretch, image_path, ndvi)`` batches.

    A checkpoint whose path contains ``"swa"`` is loaded into an
    ``AveragedModel`` wrapper.
    """
    in_channels = 4
    # The initial encoder weights are overwritten by load_state_dict below.
    model_resnest = smp.UnetPlusPlus(
        encoder_name="timm-resnest101e",
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=10,
        )
    model_effi7 = smp.UnetPlusPlus(
        encoder_name="efficientnet-b7",
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=10,
        )
    if "swa" in model_path_resnest:
        model_resnest = AveragedModel(model_resnest)
    if "swa" in model_path_effi7:
        model_effi7 = AveragedModel(model_effi7)

    # Use DEVICE everywhere (instead of .cuda()) and map the checkpoints
    # onto it so that the CPU fallback works.
    model_resnest.to(DEVICE)
    model_resnest.load_state_dict(torch.load(model_path_resnest, map_location=DEVICE))
    model_resnest.eval()
    model_effi7.to(DEVICE)
    model_effi7.load_state_dict(torch.load(model_path_effi7, map_location=DEVICE))
    model_effi7.eval()

    # cv2.imwrite fails silently when the target directory does not exist.
    os.makedirs(output_dir, exist_ok=True)

    for image, image_stretch, image_path, ndvi in test_loader:
        with torch.no_grad():
            # Test-time augmentation: flips along dims 2 and 3.
            image_flip2 = torch.flip(image,[2]).to(DEVICE)
            image_flip3 = torch.flip(image,[3]).to(DEVICE)
            image = image.to(DEVICE)
            image_stretch = image_stretch.to(DEVICE)

            output1 = model_resnest(image).cpu().data.numpy()
            output2 = model_resnest(image_stretch).cpu().data.numpy()
            output3 = model_effi7(image).cpu().data.numpy()
            output4 = model_effi7(image_stretch).cpu().data.numpy()

            # Flip the predictions back before averaging.
            output5 = torch.flip(model_resnest(image_flip2),[2]).cpu().data.numpy()
            output6 = torch.flip(model_effi7(image_flip2),[2]).cpu().data.numpy()
            output7 = torch.flip(model_resnest(image_flip3),[3]).cpu().data.numpy()
            output8 = torch.flip(model_effi7(image_flip3),[3]).cpu().data.numpy()

        output = (output1 + output2 + output3 + output4 + output5 + output6 + output7 + output8) / 8.0
        for i in range(output.shape[0]):
            pred = output[i]
            # Class ids in the saved PNG start at 1.
            pred = np.argmax(pred, axis = 0) + 1
            pred = np.uint8(pred)
            # NOTE(review): the output name is the last 10 characters of the
            # input path - presumably the file names are exactly 10
            # characters long; confirm against the data set.
            save_path = os.path.join(output_dir, image_path[i][-10:].replace('.tif', '.png'))
            print(save_path)
            cv2.imwrite(save_path, pred)
        
def test_1(channels, model_path, output_dir, test_path):
    """Predict a binary mask for every image in ``test_path``.

    Args:
        channels: Number of input channels of the model.
        model_path: Checkpoint (``state_dict``) to load. A path containing
            ``"swa"`` is loaded into an ``AveragedModel`` wrapper.
        output_dir: Directory the masks are written to, under the same file
            names as the inputs (created when missing).
        test_path: Directory with the input images.

    Every image is resized to 512x512 and scaled to [-1.6, 1.6] exactly as
    in the training loader. The saved mask is 512x512 with values 0 / 255.
    Files that OpenCV cannot read are reported and skipped.
    """
    # The initial encoder weights are overwritten by load_state_dict below.
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=channels,
        classes=1,
        activation='sigmoid',
    )

    if "swa" in model_path:
        model = AveragedModel(model)
    # Use DEVICE (instead of .cuda()) and map the checkpoint onto it so that
    # the CPU fallback works.
    model.to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    # cv2.imwrite fails silently when the target directory does not exist.
    os.makedirs(output_dir, exist_ok=True)

    for name in os.listdir(test_path):
        full_path = os.path.join(test_path, name)
        img = cv2.imread(full_path)
        if img is None:
            # cv2.imread returns None for anything it cannot decode.
            print("skip unreadable file: " + full_path)
            continue

        img = cv2.resize(img, (512,512), interpolation = cv2.INTER_NEAREST)
        image = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
        image = image.transpose(2, 0, 1)
        image = np.expand_dims(image, axis=0)
        image = torch.Tensor(image).to(DEVICE)

        with torch.no_grad():
            output = model(image).cpu().numpy()

        # Threshold the sigmoid output and store it as an 8-bit image.
        mask = np.where(output.squeeze() >= 0.5, 255, 0).astype(np.uint8)
        save_path = os.path.join(output_dir, name)
        cv2.imwrite(save_path, mask)
if __name__ == "__main__":
    start_time = time.time()

    # Checkpoints of the two-model ensemble used by test(), kept for reference:
    #   "../user_data/model_data/unetplusplus_effi7_upsample_SoftCE_dice.pth"
    #   "../user_data/model_data/unetplusplus_resnest_upsample_SoftCE_dice.pth"

    # Checkpoint written by the training script, input images, output folder.
    model_path = "./weights/_SoftCE_dice.pth"
    data_root = "./dataset/build/test/images/"
    output_dir = './result/'

    # Single-model prediction on 3-channel images.
    test_1(channels=3, model_path=model_path, output_dir=output_dir, test_path=data_root)

题外话:有什么新的比较好的网络可以评论推荐给我,我来复现贴出来大家一起用一用 

Logo

秉承“创新、开放、协作、共享”的开源价值观,致力于为大规模开源开放协同创新助力赋能,打造创新成果孵化和新时代开发者培养的开源创新生态!支持公有云使用、私有化部署以及软硬一体化私有部署。

更多推荐