【pytorch】多卡训练/混合精度/分布式训练之踩坑指北

1. 混合精度

1.1 目的

训练网络的基本上都是在N卡上面执行的,数据集比较大时,训练网络会耗费大量的时间。由于我们需要使用反向传播来更新具有细微变化的权重,因而我们在训练网络的过程中通常会选用FP32类型的数据和权重。
混合精度训练,即当你使用N卡训练你的网络时,混合精度会在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。
它的优势就是可以使你的训练时间减少一半左右。它的缺陷是只能在支持FP16操作的一些特定类型的显卡上面使用,而且会存在溢出误差和舍入误差。

在这里插入图片描述

  • FP32和FP16都是用来表示某一个数值;
  • FP32和FP16都是由符号位、指数和尾数一起组成;
  • 即FP16最大能够表示的数字是65503;
    在这里插入图片描述
  • FP16计算速度更快、更加节约内存
  • 计算同样的操作,FP16可以获得8倍的加速、2倍左右的内存扇出、节省1/2的内存资源;
  • 下图展示了执行卷积的过程(乘操作和加操作)
    上图展示了执行卷积的过程(乘操作和加操作)

1.2 F16缺点

缺点1-FP16会带来梯度溢出错误

比FP32的动态范围小了很多,因而在计算的过程中很容易出现上溢出(超出能够表示的最大数值)和下溢出(超出能够表示的最小数值)问题,溢出之后就会出现NAN的问题。在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况。

缺点2-FP16会带来舍入误差

舍入误差,即当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败,具体的细节如下图所示,由于更新的梯度值超出了FP16能够表示的最小值的范围,因此该数值将会被舍弃,这个权重将不进行更新。请添加图片描述

1.3 解决方案:

APEX是什么

APEX是英伟达开源的,完美支持PyTorch框架,用于改变数据格式来减小模型显存占用的工具。其中最有价值的是amp(Automatic Mixed Precision),将模型的大部分操作都用Float16数据类型测试,一些特别操作仍然使用Float32。并且用户仅仅通过三行代码即可完美将自己的训练代码迁移到该模型。实验证明,使用Float16作为大部分操作的数据类型,并没有降低参数,在一些实验中,反而由于可以增大Batch size,带来精度上的提升,以及训练速度上的提升。

查看能否正确导入apex
from apex import amp

使用混合精度训练。所谓的混合精度训练,即在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差,这样可以很好的解决舍入误差的问题。
损失放大。有些情况下,即使使用了混合精度训练的方法,由于激活梯度的值太小,会造成下溢出,从而导致模型无法收敛的问题。所谓的损失放大,即反向传播前,将损失变化(dLoss)手动增大2^k,因此反向传播时得到的中间变量(激活函数梯度)则不会溢出;反向传播后,将权重梯度缩倍,恢复正常值。

1.4 混合精度代码

代码实例
# ===> 修改前
# coding=utf-8
import torch
N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device=“cuda”)
y = torch.randn(N, D_out, device=“cuda”)
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for t in range(500):
	y_pred = model(x)
	loss = torch.nn.functional.mse_loss(y_pred, y)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	
# ===> 修改后
# coding=utf-8
import torch
N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device=“cuda”)
y = torch.randn(N, D_out, device=“cuda”)
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
for t in range(500):
	y_pred = model(x)
	loss = torch.nn.functional.mse_loss(y_pred, y)
	optimizer.zero_grad()
	with amp.scale_loss(loss, optimizer) as scaled_loss:
		scaled_loss.backward()
	optimizer.step()

代码分析
model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)

这行代码的主要作用是对模型和优化器执行初始化操作,方便后续的混合精度训练。其中opt_level表示优化的等级,当前支持4个等级的优化。
在这里插入图片描述

  • 当opt_level='00’时,表示的是当前执行FP32训练,即正常的训练,当前优化等级执行的具体操作是cast_model_type=torch.float32、patch_torch_function= False、keep_batchnorm_fp32=None、master_weight=False、loss_scale=1.0。
  • 当opt_level='01’时,表示的是当前使用部分FP16混合训练,当前优化等级执行的具体操作是cast_model_type=None、patch_torch_function=True、keep_batch norm_fp32=None、master_weight=None、loss_scale=“dynamic”。
  • 当opt_level='02’时,表示的是除了BN层的权重外,其他层的权重都使用FP16执行训练,当前优化等级执行的具体操作是cast_model_type=torch.float16、patch _torch_function=False、keep_batchnorm_fp32=True、master_weight =True 、loss_scale=“dynamic”。
  • 当opt_level='03’时,表示的是默认所有的层都使用FP16执行计算,当keep_batch norm_fp32=True,则会使用cudnn执行BN层的计算,该优化等级能够获得最快的速度,但是精度可能会有一些较大的损失。当前优化等级执行的具体操作是cast_ model_type=torch.float16、patch _torch_function=False、keep_batchnorm _fp32=False、master_weight =False、loss_scale=1.0
  • 注意,这里opt_level='O1’是大写字母“O”
with amp.scale_loss(loss, optimizer) as scaled_loss:
	scaled_loss.backward()

这行代码的主要作用是在反向传播前进行梯度放大来进行更新,在反向传播后进行梯度缩放,返回原来的值,但是可以很好的解决由于梯度值太小模型无法更新的问题。

效果

在这里插入图片描述
在这里插入图片描述

1.5 问题

  • 不支持DataParallel多卡训练
  • 溢出问题(必看)
    因为Float16保存数据位数少了,能保存数据的上限和下限的绝对值也小了。如果我们在处理分割类问题,需要用到一些涉及到求和的操作,如sigmoid,softmax,这些操作都涉及到求和。分割问题特征图都很大,求个sigmoid可能会导致数据溢出,得到错误的结果。所以针对这些操作,仍然使用float32作为数据格式。那么如何基于上面的代码修改呢。
from apex import amp
class xxxNet(Module):
	def __init__(using_map=False)
		...
		...
		if using_amp:
		     amp.register_float_function(torch, 'sigmoid')
		     amp.register_float_function(torch, 'softmax')

用register_float_function指明后面的函数需要使用float类型。注意第二实参是string类型
和register_float_function相似的注册函数还有

amp.register_half_function(module, function_name)
amp.register_float_function(module, function_name)
amp.register_promote_function(module, function_name)

你必须在使用amp.initialize之前使用注册函数,所以最好的位置就放在模型的构造函数中

2 多卡训练

2.1 目的

最简单的单机多卡操作nn.DataParallel。但是很遗憾这种操作还不够优秀。
单机多卡的办法还有很多(如下)

  • nn.DataParallel 简单方便的 nn.DataParallel
  • torch.distributed 使用. torch.distributed 加速并行训练
  • apex 使用 apex 再加速。
为啥非要单机多卡?

答1:加速神经网络训练最简单的办法就是上GPU,如果一块GPU还是不够,就多上几块。事实上,比如BERT和GPT-2这样的大型语言模型甚至是在上百块GPU上训练的。为了实现多GPU训练,我们必须想一个办法在多个GPU上分发数据和模型,并且协调训练过程。

单机多卡操作nn.DataParallel,哪里不好?

答2:要回答这个问题我们得先简单回顾一下nn.DataParallel,要使用这玩意,我们将模型和数据加载到多个 GPU 中,控制数据在 GPU 之间的流动,协同不同 GPU 上的模型进行并行训练。具体怎么操作?
我们只需要用 DataParallel 包装模型,再设置一些参数即可。需要定义的参数包括:

  • 参与训练的 GPU 有哪些,device_ids=gpus
  • 用于汇总梯度的 GPU 是哪个,output_device=gpus[0]

DataParallel 会自动帮我们将数据切分 load 到相应 GPU,将模型复制到相应 GPU,进行正向传播计算梯度并汇总。

单机单卡
if not doka_training:
        gpu_ids = [0,1,2,3]
        torch.cuda.set_device('cuda:{}'.format(gpu_ids[1]))
        print('=> using GPU id is: {}'.format(gpu_ids[1]))
单机多卡
model = nn.DataParallel(model.cuda(), device_ids=gpus, output_device=gpus[0])
# 值得注意的是,模型和数据都需要先 load 进 GPU 中,
# DataParallel 的 module 才能对其进行处理,否则会报错
使用DataParallel完整代码指北
# main.py
import torch
import torch.distributed as dist

gpus = [0, 1, 2, 3]
torch.cuda.set_device('cuda:{}'.format(gpus[0]))

train_dataset = ...

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=...)

model = ...
model = nn.DataParallel(model.to(device), device_ids=gpus, output_device=gpus[0])

optimizer = optim.SGD(model.parameters())

for epoch in range(100):
   for batch_idx, (data, target) in enumerate(train_loader):
      images = images.cuda(non_blocking=True)
      target = target.cuda(non_blocking=True)
      ...
      output = model(images)
      loss = criterion(output, target)
      ...
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
DataParallel大缺点:

在每个训练批次(batch)中,因为模型的权重都是在一个进程上先算出来,然后再把他们分发到每个GPU上,所以网络通信就成为了一个瓶颈,而GPU使用率也通常很低。
除此之外,nn.DataParallel 需要所有的GPU都在一个节点(一台机器)上,且并不支持 Apex 的混合精度训练。

3 分布式多卡训练

3.1 目的

  • 解决DataParallel训练存在的问题
  • DataParallel:单进程控制多 GPU。
  • DistributedDataParallel:多进程控制多 GPU,一起训练模型。

3.2 解决方案

通过 MPI 实现 CPU 通信,通过 NCCL 实现 GPU 通信。
官方也曾经提到用 DistributedDataParallel 解决 DataParallel 速度慢,GPU 负载不均衡的问题,目前已经很成熟了。

使用 nn.DistributedDataParallel 进行Multiprocessing可以在多个gpu之间复制该模型,每个gpu由一个进程控制。(如果你想,也可以一个进程控制多个GPU,但这会比控制一个慢得多。也有可能有多个工作进程为每个GPU获取数据,但为了简单起见,本文将省略这一点。)这些GPU可以位于同一个节点上,也可以分布在多个节点上。每个进程都执行相同的任务,并且每个进程与所有其他进程通信。

只有梯度会在进程/GPU之间传播,这样网络通信就不至于成为一个瓶颈了。

注意 ⚠️ 多进程训练需要注意以下事项:
  • 在喂数据的时候,一个batch被分到了好几个进程,每个进程在取数据的时候要确保拿到的是不同的数据(DistributedSampler)
  • 要告诉每个进程自己是谁,使用哪块GPU(args.local_rank)
  • 在做BatchNormalization的时候要注意同步数据。
3.2.1 执行脚本(可选)

在多进程的启动方面,无需自己手写 multiprocess 进行一系列复杂的CPU、GPU分配任务。
PyTorch为我们提供了一个很方便的启动器 torch.distributed.launch 用于启动文件,所以我们运行训练代码的方式就变成了这样:

CUDA_VISIBLE_DEVICES=0,1,2,3 python \-m torch.distributed.launch \--nproc_per_node=4 main.py

其中的 --nproc_per_node 参数用于指定为当前主机创建的进程数,由于我们是单机多卡,所以这里node数量为1,所以我们这里设置为所使用的GPU数量即可。

进程数 = 节点数 * 显卡数

sh脚本如下:

执行脚本

bash launch.sh

脚本代码

python                            \ 
    -m torch.distributed.launch   \
    --nproc_per_node = 4          \
    train.py                      \
3.2.2 训练代码初始化

导入库

import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP
from apex import amp

启动器会将当前进程的(其实就是 GPU的)index 通过参数传递给 python,我们可以这样获得当前进程的 index:即通过参数 local_rank 来告诉我们当前进程使用的是哪个GPU,用于我们在每个进程中指定不同的device

需要一个脚本,用来启动一个进程的每一个GPU。每个进程需要知道使用哪个GPU,以及它在所有正在运行的进程中的阶序(rank)。而且,我们需要在每个节点上运行脚本。

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1,
                        type=int, metavar='N')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('--epochs', default=2, type=int, 
                        metavar='N',
                        help='number of total epochs to run')
    args = parser.parse_args()
    #########################################################
    args.world_size = args.gpus * args.nodes                #
    # 基于结点数以及每个结点的GPU数,我们可以计算 world_size 或者需要运行的总进程数,这和总GPU数相等。
    os.environ['MASTER_ADDR'] = '10.57.23.164'              #
    # 告诉Multiprocessing模块去哪个IP地址找process 0以确保初始同步所有进程。
    os.environ['MASTER_PORT'] = '8888'                      #
    mp.spawn(train, nprocs=args.gpus, args=(args,))         #
    # 现在,我们需要生成 args.gpus 个进程, 每个进程都运行 train(i, args), 其中 i 从 0 到 args.gpus - 1。
    # 注意, main() 在每个结点上都运行, 因此总共就有 args.nodes * args.gpus = args.world_size 个进程.
    #########################################################

其中

  • args.nodes 是我们使用的结点数
  • args.gpus 是每个结点的GPU数
  • args.nr 是当前结点的阶序rank,这个值的取值范围是 0 到 args.nodes - 1
针对训练函数进行修改
def train(gpu, args):
    ############################################################
    # 这里是该进程在所有进程中的全局rank(一个进程对应一个GPU)。
    rank = args.nr * args.gpus + gpu	
    '''
    初始化进程并加入其他进程。这就叫做“blocking”,也就是说只有当所有进程都加入了,单个进程才会运行。
    这里使用了 nccl 后端,因为Pytorch文档说它是跑得最快的。 
    init_method 让进程组知道去哪里找到它需要的设置。
    在这里,它就在寻找名为 MASTER_ADDR 以及 MASTER_PORT 的环境变量,这些环境变量在 main 函数中设置过。
    当然,本来可以把world_size 设置成一个全局变量,
    不过本脚本选择把它作为一个关键字参量(和当前进程的全局阶序global rank一样)
    '''                          
    dist.init_process_group(                                   
    	backend='nccl',                                         
   		init_method='env://',                                   
    	world_size=args.world_size,                              
    	rank=rank                                               
    )                                                          
    ############################################################
    
    torch.manual_seed(0)
    model = ConvNet()
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    batch_size = 100
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)
    
    ###############################################################
    # Wrap the model
    # 将模型封装为一个 DistributedDataParallel 模型。这将把模型复制到GPU上进行处理。
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[gpu])
    ###############################################################

    # Data loading code
    train_dataset = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )                                               
    ################################################################
    # nn.utils.data.DistributedSampler 确保每个进程拿到的都是不同的训练数据切片。
    train_sampler = torch.utils.data.distributed.DistributedSampler(
    	train_dataset,
    	num_replicas=args.world_size,
    	rank=rank
    )
    ################################################################

    train_loader = torch.utils.data.DataLoader(
    	dataset=train_dataset,
       batch_size=batch_size,
    ##############################
    # 因为用了 nn.utils.data.DistributedSampler 所以不能用正常的办法做shuffle。
       shuffle=False,            #
    ##############################
       num_workers=0,
       pin_memory=True,
    #############################
      sampler=train_sampler)    # 
    #############################
    ...

为了简单起见,上面的代码去掉了简单循环并用 … 代替,不过你可以在这里看到完整脚本 。

要在4个节点上运行它(每个节点上有8个gpu),我们需要4个终端(每个节点上有一个)。在节点0上(由 main 中的第13行设置):

python src/mnist-distributed.py -n 4 -g 8 -nr 0

而在其他的节点上:

python src/mnist-distributed.py -n 4 -g 8 -nr i

其中 i∈1,2,3. 换句话说,我们要把这个脚本在每个结点上运行脚本,让脚本运行 args.gpus 个进程以在训练开始之前同步每个进程。

注意,脚本中的batchsize设置的是每个GPU的batchsize,因此实际的batchsize要乘上总共的GPU数目(worldsize)。

3.2.3 使用apex进行混合精度训练

    # Wrap the model
    ##############################################################
    '''
    amp.initialize 将模型和优化器为了进行后续混合精度训练而进行封装。
    注意,在调用 amp.initialize 之前,模型模型必须已经部署在GPU上。 
    opt_level 从 O0 (全部使用浮点数)一直到 O3 (全部使用半精度浮点数)。
    而 O1 和 O2 属于不同的混合精度程度,具体可以参阅APEX的官方文档。注意之前数字前面的是大写字母O。
    '''
    model, optimizer = amp.initialize(model, optimizer, 
                                      opt_level='O2')
    # apex.parallel.DistributedDataParallel 是一个 nn.DistributedDataParallel 的替换版本。
    # 我们不需要指定GPU,因为Apex在一个进程中只允许用一个GPU。
    # 且它也假设程序在把模型搬到GPU之前已经调用了 torch.cuda.set_device(local_rank)(line 10) .
    model = DDP(model)
    ##############################################################
    # Data loading code
	...
    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
    ##############################################################
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
    ##############################################################
            optimizer.step()
     ...
同步BN

Apex为我们实现了同步BN,用于解决单GPU的minibatch太小导致BN在训练时不收敛的问题。

from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel
注意顺序:三个顺序不能错
model = convert_syncbn_model(UNet3d(n_channels=1, n_classes=1)).to(device)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
model = DistributedDataParallel(model, delay_allreduce=True)

调用该函数后,Apex会自动遍历model的所有层,将BatchNorm层替换掉。

def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
    rt = tensor.clone()
    distributed.all_reduce(rt, op=distributed.reduce_op.SUM)
    rt /= distributed.get_world_size()#总进程数
    return rt

# calculate loss
loss = criterion(predict, labels)
reduced_loss = reduce_tensor(loss.data)
train_epoch_loss += reduced_loss.item()

# 注意在写入TensorBoard的时候只让一个进程写入就够了:

# TensorBoard
if args.local_rank == 0:
    writer.add_scalars('Loss/training', {
        'train_loss': train_epoch_loss,
        'val_loss': val_epoch_loss
    }, epoch + 1)
模型保存

在保存模型的时候,由于是Apex混合精度模型,我们需要使用Apex提供的保存、载入方法(见Apex README)

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...

3.3 多卡后的 batch_size 和 learning_rate 的调整

从理论上来说,lr = batch_size * base lr,因为 batch_size 的增大会导致你 update 次数的减少,所以为了达到相同的效果,应该是同比例增大的。
但是更大的 lr 可能会导致收敛的不够好,尤其是在刚开始的时候,如果你使用很大的 lr,可能会直接爆炸,所以可能会需要一些 warmup 来逐步的把 lr 提高到你想设定的 lr。
实际应用中发现不一定要同比例增长,有时候可能增大到 batch_size/2 倍的效果已经很不错了。
在我的实验中,使用8卡训练,则增大batch_size 8倍,learning_rate 4倍是差不多的。

代码

import os
import datetime
import argparse
from tqdm import tqdm
import torch
from torch import distributed, optim
from torch.utils.data import DataLoader
#每个进程不同sampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
#混合精度
from apex import amp
#同步BN
from apex.parallel import convert_syncbn_model
#Distributed DataParallel
from apex.parallel import DistributedDataParallel

from models import UNet3d
from datasets import IronGrain3dDataset
from losses import BCEDiceLoss
from eval import eval_net

train_images_folder = '../../datasets/IronGrain/74x320x320/train_patches/images/'
train_labels_folder = '../../datasets/IronGrain/74x320x320/train_patches/labels/'
val_images_folder = '../../datasets/IronGrain/74x320x320/val_patches/images/'
val_labels_folder = '../../datasets/IronGrain/74x320x320/val_patches/labels/'


def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    return args


def main():
    args = parse()

#设置当前进程的device,GPU通信方式为NCCL
    torch.cuda.set_device(args.local_rank)
    distributed.init_process_group(
        'nccl',
        init_method='env://'
    )

#制作Dataset和sampler
    train_dataset = IronGrain3dDataset(train_images_folder, train_labels_folder)
    val_dataset = IronGrain3dDataset(val_images_folder, val_labels_folder)
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    epochs = 100
    batch_size = 8
    lr = 2e-4
    weight_decay = 1e-4
    device = torch.device(f'cuda:{args.local_rank}')

#制作DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4,
                              pin_memory=True, sampler=train_sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4,
                            pin_memory=True, sampler=val_sampler)

#3步曲:同步BN,初始化amp,DistributedDataParallel封装
    net = convert_syncbn_model(UNet3d(n_channels=1, n_classes=1)).to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')
    net = DistributedDataParallel(net, delay_allreduce=True)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50, 75], gamma=0.2)
    criterion = BCEDiceLoss().to(device)

    if args.local_rank == 0:
        print(f'''Starting training:
            Epochs:          {epochs}
            Batch size:      {batch_size}
            Learning rate:   {lr}
            Training size:   {len(train_dataset)}
            Validation size: {len(val_dataset)}
            Device:          {device.type}
        ''')
        writer = SummaryWriter(
            log_dir=f'runs/irongrain/unet3d_32x160x160_BS_{batch_size}_{datetime.datetime.now()}'
        )
    for epoch in range(epochs):
        train_epoch_loss = 0
        with tqdm(total=len(train_dataset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:

            images = None
            labels = None
            predict = None

            # train
            net.train()
            for batch_idx, batch in enumerate(train_loader):
                images = batch['image']
                labels = batch['label']
                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.float32)

                predict = net(images)

                # calculate loss
                # reduce不同进程的loss
                loss = criterion(predict, labels)
                reduced_loss = reduce_tensor(loss.data)
                train_epoch_loss += reduced_loss.item()

                # optimize
                optimizer.zero_grad()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                scheduler.step()

                # set progress bar
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(images.shape[0])

            train_epoch_loss /= (batch_idx + 1)

            # eval
            val_epoch_loss, dice, iou = eval_net(net, criterion, val_loader, device, len(val_dataset))

            # TensorBoard
            if args.local_rank == 0:
                writer.add_scalars('Loss/training', {
                    'train_loss': train_epoch_loss,
                    'val_loss': val_epoch_loss
                }, epoch + 1)

                writer.add_scalars('Metrics/validation', {
                    'dice': dice,
                    'iou': iou
                }, epoch + 1)

                writer.add_images('images', images[:, :, 0, :, :], epoch + 1)
                writer.add_images('Label/ground_truth', labels[:, :, 0, :, :], epoch + 1)
                writer.add_images('Label/predict', torch.sigmoid(predict[:, :, 0, :, :]) > 0.5, epoch + 1)

            if args.local_rank == 0:
                torch.save(net, f'unet3d-epoch{epoch + 1}.pth')


def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
    rt = tensor.clone()
    distributed.all_reduce(rt, op=distributed.reduce_op.SUM)
    rt /= distributed.get_world_size()#进程数
    return rt


if __name__ == '__main__':
    main()

问题总结

总结

参考资料

  • Thanks for https://blog.csdn.net/WZZ18191171661/article/details/103218532
  • Thanks for https://blog.csdn.net/qq_34914551/article/details/103203862
  • Thanks for https://jishuin.proginn.com/p/763bfbd2f4d2
Logo

鸿蒙生态一站式服务平台。

更多推荐