pytorch-unet

来源:https://github.com/milesial/Pytorch-UNet

前两天搞了一下图像分割,用了下unet。之前没怎么用过。复现了一下18年的une pytorch 版本,记录学习一下 (//过了一年了来补充完善一下。。)

1. 主函数

if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    # 如果是RGB图像 n_channels=3,如果是医学图像(大部分是灰度图)n_channels=1。n_classes是
    # 你要分割的类别数加1,比如你的前景有两类,n_classes = 3哦
    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  learning_rate=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        raise

有一个定义的函数get_args():

def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()

argparser主要有三个步骤:

1. Argumentparser()对象,将命令行解析成python数据类型所需要的全部信息。

2.add_argument()方法添加函数,主要定batchsize,lr,epochs,这些乱七八糟的东西,这样方便在命令行直接修改。

3. 这个封装的函数相当于最终解析出来(parparse_args())。

创建解析器 - 添加命令行参数-解析参数

主要这个函数其实就是干了三个事:1.通过arg parser设定需要的轮次,bs,学习率等 。2. 设定输入图像,输出图像尺寸。 3. 是用cpu,一块gpu还是用多块gpu

2. 训练模型

1. 创建数据集 (这一步的作用主要是实现loading data 和 augmentation)

    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

定义数据集的路径,mask的路径。可以在这做数据增强,下面举个例子是训练集的数据增强,一般都是用 transforms.Compose的方法,比如下图就是用了随机旋转,随即翻转,转成tensor,做标准化(如果想添加自己的数据增强方法就在transform里自己定义一个类,然后compose进来实现自定义数据增强)。

transform_train = transforms.Compose([
    transforms.RandomRotation(degrees=8),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std), ])

2. 划分数据集(训练集,验证集,测试集,通过random_split对数据集进行一定比例的随即划分。

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

 3. 创建dataloder

# 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

首先简单介绍一下啥是dataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个一个Batch Size大小的Tensor,用于后面的训练。

例如:定义的train_loder继承了dataloder,用自己的train_set数据集,按batchsize分成一批一批的tensor去训练;shuffle是每一个epoch结束之后,是否要重新排序;num_worker这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。根据我的经验哈,一般如果出席那这些osa,显存太小这种报错,就把num_workers改成0,或者4就行了。一般gpu上跑还是8或者16差不多,这个是影响训练速度的,num_workers太小的话,gpu利用率会非常低,训练不好。(很多时候我在本地都是num = 0,然后放到服务器上忘了改num,直接显存就爆炸了。。)

4. 创建优化器,定义学习率策略,定义损失函数

 # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0

optimizer能够保持当前参数状态并基于计算得到的梯度进行参数更新,可以继承网络初始参数,权重衰减,学习率策略啊一些东西。方法分为2大类:一大类方法是SGD及其改进(加Momentum)另外一大类是Per-parameter adaptive learning rate methods(逐参数适应学习率方法),包括AdaGrad、RMSProp、Adam等。这东西就跟机器学习当中选择什么算法来进行梯度更新一样。

我的经验:一般优化器就是SGD或者Adam; 学习率策略一般是  lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 或者过多少轮减半什么的,初始学习率最多0.1,一般0.01或者0.001;bs是2,4,8,16。

5. 开始训练 begin

# 5. Begin training
    for epoch in range(1, epochs+1):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

在每一个epoch里,通过train_loder得到多少个batch,每个batch,每个batch训练。通过网络分割得到的masks_prede和传入的true_masks进行loss计算,在优化器内不断反向传播,更新梯度,更新优化器,使得loss越来越小,并且趋于稳定。

这里主要就是loss的选择了,一般就是交叉熵和dice_loss

6. 计算

       val_score = evaluate(net, val_loader, device)
       scheduler.step(val_score)

       logging.info('Validation Dice score: {}'.format(val_score))

通过训练的网络对验证集图片进行预测,然后与验证集的true_masks进行比较得到精度。

7. 保存权重

 if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
            logging.info(f'Checkpoint {epoch + 1} saved!')

在这里可以将训练的每一轮参数保存下来,保存成pth文件,到时候在预测的时候直接用就可以了。

3. model

也就是训练当中所用到的net的结构是什么样子的

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

结构比较简单,在encoder阶段,主要有几个模块

第一个模块就是doublecov(两个(conv2d,bn,relu)),主要用在最开始将三通道图片转化为64通道图片。照例子来说,每一次unet conv后,图片尺寸都会下降2,但是在代码中

out_size = (in_size - K + 2P)/ S +1

特意将大小设成3,padding设成1,stride设成1,这样在做conv的时候图片尺寸就不会发生变化了

第二个模块就是down模块(maxpool2d,doubleconv),每次将图片尺寸减半,并在池化后进行conv的操作,增加通道数。

在decoder阶段:

up模块(unsample)+conv 上采样将图片尺寸增加,conv将通道数减少,并且和endocder同层的特征图进行连接

最后outc模块,看你是想输出几通道的图片,就有几个卷积核就好了。

注:复现时候一些bug:

debug:

AssertionError: Either no mask or multiple masks found for the ID 0008052191_9: []

解决方案:找到了img_file路径,mask file路径找不到。在data_loading里将mask_suffix改为空,如果你的img和mask是一摸一样的名字的话。

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

解决方案:是分类标签越界的问题,最终mask是要分0,1的。项目当中数据集mask是0,1。但是我的image跟mask都是0-255,位深度24,所以原项目是只将img除了255,代码中只要将is_mask改成False就好了。

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 1, 256, 256]

解决方案:这个报错比较明显,可以打印一下自己在求loss的时候的target跟ground truth,将图像尺寸reshape一下(比如:true_mask.reshape(1,256,256)就好了)。同样在验证集的时候也遇到这个问题,同样的解决方法自然就完成了。

wandb.errors.CommError: check_hostname requires server_hostname

解决方案:这个我也不大懂反正大概意思因为我开了翻墙软件,可能wandb那里出现了什么问题,把他关了就好了。

BrokenPipeError: [Errno 32] Broken pipe

解决方案:好像还有是说os:显存太小一类的,将num_workers改成0就好了。

Logo

为武汉地区的开发者提供学习、交流和合作的平台。社区聚集了众多技术爱好者和专业人士,涵盖了多个领域,包括人工智能、大数据、云计算、区块链等。社区定期举办技术分享、培训和活动,为开发者提供更多的学习和交流机会。

更多推荐