本笔记为阿里云天池龙珠计划深度学习训练营的学习内容,链接为:

https://tianchi.aliyun.com/specials/promotion/aicampdl

目录

  1. SRGAN超分辨模型

  2. SRGAN模型训练与测试
    1. 项目解读
    2. 模型训练
    3. 模型测试

本项目需要在 GPU 环境下运行,点击本页面最右边 < 进行切换

SRGAN超分辨模型(Super-Resolution Generative Adversarial Network

随着生成对抗网络GAN的发展,生成器和判别器的对抗学习机制在图像生成任务中展现出很强大的学习能力。Twitter的研究者们使用ResNet作为生成器结构,使用VGG作为判别器结构,提出了SRGAN模型,这是本次实践课使用的模型,其结构示意图如下:

Image

SRGAN模型

生成器结构包含了若干个不改变特征分辨率的残差模块和多个基于亚像素卷积的后上采样模块。

判别器结构则包含了若干个通道数不断增加的卷积层,每次特征通道数增加一倍时,特征分辨率降低为原来的一半。

SRGAN模型的损失函数包括两部分,内容损失与对抗损失。

对抗损失就是标准的GAN损失,而内容损失则是基于VGG网络特征构建,它代替了之前SRCNN使用的MSE损失函数,如下:

SRGAN通过生成器和判别器的对抗学习取得了视觉感知上更好的重建结果。不过基于GAN的模型虽然可以取得好的超分结果,但是也往往容易放大噪声。

SRGAN模型训练与测试

1. 项目解读

下面我们首先来剖析整个项目的代码。

1.1 数据集和基准模型

首先我们来介绍使用的数据集和基准模型,大多数超分重建任务的数据集都是通过从高分辨率图像进行降采样获得,这里我们也采用这样的方案。数据集既可以选择ImageNet这样包含上百万图像的大型数据集,也可以选择模式足够丰富的小数据集,这里我们选择一个垂直领域的高清人脸数据集,CelebA-HQ。CelebA-HQ数据集发布于2019年,包含30000张包括不同属性的高清人脸图,其中图像大小均为320×320。

数据集放置在项目根目录的 dataset 目录下,包括两个子文件夹,train 和 val。

在项目开始之前需要加载数据集和预训练模型,加载方式如下图所示: 仅第一次使用时操作!!!

Image

由于数据集较大,加载需要较长时间,加载完毕后,会有弹窗显示

Image

数据下载成功后,会得到两个 zip 文件,一个是数据集(Face_SuperResolution_Dataset.zip),一个是预训练模型(vgg16-397923af.zip)

Image

运行下方代码解压数据集,第一次使用时运行即可!!,当显示 10 个 . 时,代表解压完成。

# 屏蔽解压过程,并且在以“ ”开头的行,且是3000的整数倍时,打印“.”,30W数据解压完毕打印10个“.”;-o覆盖已存在的文件;|管道,将前一个命令的输出作为后一个命令的输入;
!unzip -o ./downloads/99975/Face_SuperResolution_Dataset.zip | awk 'BEGIN {ORS=" "} {if(NR%3000==0)print "."}'
. . . . . . . . . . 

解压完成后会得到一个 dataset 文件夹,其文件结构如下

dataset
    - train
    - val

运行下方代码解压预训练模型,第一次使用时运行即可!!

!mkdir checkpoints
!unzip -o vgg16-397923af.zip -d ./checkpoints/
Archive:  vgg16-397923af.zip

  inflating: ./hub/checkpoints/vgg16-397923af.pth  

1.2 数据集接口

下面我们从高分辨率图进行采样得到低分辨率图,然后组成训练用的图像对,核心代码如下:

from os import listdir  # 返回指定文件夹下包含的文件和文件夹列表
from os.path import join  # 拼接路径
from PIL import Image  # 加载图像
from torch.utils.data.dataset import Dataset  # 数据集基类
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize  # 图像预处理
def is_image_file(filename):

          # 通过文件扩展名判断是否为图像,若是返回True

    return any(
        filename.endswith(extension)
        for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
# 基于上采样因子对裁剪尺寸进行调整,使其为upscale_factor的整数倍
def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)  # %——取模并返回余数
# 训练集高分辨率图预处理函数
def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),  # 随机剪裁后输出尺寸为crop_size,crop_size小于原始图像的H和W
        ToTensor(),  # 转化为tensor
    ])
# 训练集低分辨率图预处理函数
def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        ToPILImage(),  # 将tensor转变为PILImage,然后进行预处理

                  # 将图像短边缩小到crop_size // upscale_factor的整数值,并保持图像高宽比不变

        Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
        ToTensor()
    ])
def display_transform():
    return Compose([ToPILImage(), Resize(400), CenterCrop(400), ToTensor()])
# 训练数据集类,继承Dataset类
class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()

                  # 将所有训练图像完整路径存储到一个list里,for-if-join

        self.image_filenames = [
            join(dataset_dir, x) for x in listdir(dataset_dir)
            if is_image_file(x)
        ]  # 获得所有图像
        crop_size = calculate_valid_crop_size(crop_size,
                                              upscale_factor)  # 获得裁剪尺寸
        self.hr_transform = train_hr_transform(crop_size)  # 高分辨率图预处理函数
        self.lr_transform = train_lr_transform(crop_size,  
                                               upscale_factor)  # 低分辨率图预处理函数,此时crop_size已处理为upscale_factor的整数倍
    # 数据集迭代指针
    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(
            self.image_filenames[index]))  # 随机裁剪获得高分辨率图
        lr_image = self.lr_transform(hr_image)  # 获得低分辨率图,基于随机剪裁后的图片进行缩小,两张图像内容一致
        return lr_image, hr_image  # 返回训练集图像对
    def __len__(self):
        return len(self.image_filenames)  # 返回所有训练图片的数量
# 验证数据集类
class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [
            join(dataset_dir, x) for x in listdir(dataset_dir)
            if is_image_file(x)
        ]
    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        # 获得图像窄边,获得裁剪尺寸,size返回图像的(h,w)
        h, w = hr_image.size

                  # 验证集的剪裁尺寸由每张图像的窄边决定,数据集里的图像尺寸是320*320

        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor,
                          interpolation=Image.BICUBIC)

                  # 高分辨率图像与原始图像大小一致

        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)  # 中心裁剪获得高分辨率图
        lr_image = lr_scale(hr_image)  # 获得低分辨率图,基于高分辨率图像,图像尺寸变小,但是清晰度没变
        hr_restore_img = hr_scale(lr_image)  # 将低分辨率图像放大到与高分辨率图像一致的尺寸,图像会变模糊
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(
            hr_image)  # 将处理后的图像转化为张量并返回,网络只能处理张量
    def __len__(self):
        return len(self.image_filenames)
# 测试数据集类
class TestDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        self.lr_filenames = [
            join(self.lr_path, x) for x in listdir(self.lr_path)
            if is_image_file(x)
        ]
        self.hr_filenames = [
            join(self.hr_path, x) for x in listdir(self.hr_path)
            if is_image_file(x)
        ]
    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        h, w = lr_image.size
        hr_image = Image.open(self.hr_filenames[index])
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w),
                          interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(
            hr_restore_img), ToTensor()(hr_image)
    def __len__(self):
        return len(self.lr_filenames)

从上述代码可以看出,包含了两个预处理函数接口,分别是train_hr_transform,train_lr_transform。train_hr_transform包含的操作主要是随机裁剪,而train_lr_transform包含的操作主要是缩放。

另外还有一个函数calculate_valid_crop_size,对于训练集来说,它用于当配置的图像尺寸crop_size不能整除上采样因子upscale_factor时对crop_size进行调整,我们在使用的时候应该避免这一点,即配置crop_size让它等于upscale_factor的整数倍。对于验证集,图像的窄边min(w, h)会被用于crop_size的初始化,所以该函数的作用是当图像的窄边不能整除上采样因子upscale_factor时对crop_size进行调整。

训练集类TrainDatasetFromFolder包含了若干操作,它使用train_hr_transform从原图像中随机裁剪大小为裁剪尺寸的正方形的图像,使用train_lr_transform获得对应的低分辨率图。而验证集类ValDatasetFromFolder则将图像按照调整后的crop_size进行中心裁剪,然后使用train_lr_transform获得对应的低分辨率图。

在这里我们只使用了随机裁剪作为训练时的数据增强操作,实际训练工程项目时,应该根据需要添加多种数据增强操作才能获得泛化能力更好的模型。

1.3 生成器

生成器是一个基于残差模块的上采样模型,它的定义包括残差模块,上采样模块以及主干模型,如下:

import math  # 数学函数
import torch
from torch import nn  # 神经网络模型
# 生成模型,先通过卷积层提取特征信息,再通过上采样层放大
class Generator(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))
        super(Generator, self).__init__()  # 初始化基类
        # 第一个卷积层,卷积核大小为9×9,输入通道数为3(彩色图片),输出通道数为64,使用Sequential容器将卷积层和激活层打包
        self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4),
                                    nn.PReLU())
        # 6个残差模块,完成了12次卷积、12次归一化、6次激活、6次跳跃连接
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64))
        # upsample_block_num个上采样模块,每一个上采样模块恢复2倍的上采样倍率

                  # block8包含两个上采样层和一个卷积层

        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        # 最后一个卷积层,卷积核大小为9×9,输入通道数为64,输出通道数为3
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)  # Sequential容器里有两个上采样层和一个卷积层
    def forward(self, x):  # init方法创造积木,forward方法搭积木
        block1 = self.block1(x)  # input经block1后,channel变为64,H和W不变
        block2 = self.block2(block1)  # block1经残差块block2后,channel、H、W均未变
        block3 = self.block3(block2)  
        block4 = self.block4(block3)  
        block5 = self.block5(block4)  
        block6 = self.block6(block5)  
        block7 = self.block7(block6)  # block7本身不是残差块,但是下面和block1相加后变成了残差块
        block8 = self.block8(block1 + block7)  # block8经过两次上采样和一次卷积后,channel变为3,H和W变为输入的4倍
        return (torch.tanh(block8) + 1) / 2  # 经激活层后输出范围为[-1,1],最终调整为[0,1]

               

# 残差模块,解决深层网络梯度消失和信息量减少问题
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        # 两个卷积层,卷积核大小为3×3,通道数不变,图像尺寸也不变
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)  # 批量归一化层
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)  # 第二个批量归一化层之后没有连接激活层
    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        return x + residual  # 增加跳跃连接

# 上采样模块,每一个恢复分辨率为2
class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        # 卷积层,输入通道数为in_channels,输出通道数为in_channels * up_scale ** 2,图像尺寸不变
        self.conv = nn.Conv2d(in_channels,
                              in_channels * up_scale**2,
                              kernel_size=3,
                              padding=1)
        # PixelShuffle上采样层,来自于后上采样结构
        self.pixel_shuffle = nn.PixelShuffle(up_scale)  # 像素重排,降低通道数,增加H和W,既低分辨率调整为高分辨率
        self.prelu = nn.PReLU()
    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x

在上述的生成器定义中,调用了nn.PixelShuffle模块来实现上采样,它的具体原理在上节基于亚像素卷积的后上采样ESPCN模型中有详细介绍。

1.4 判别器

判别器是一个普通的类似于VGG的CNN模型,完整定义如下:

# 残差模块
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # 第1个卷积层,卷积核大小为3×3,输入通道数为3,输出通道数为64
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            # 第2个卷积层,卷积核大小为3×3,输入通道数为64,输出通道数为64
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # 第3个卷积层,卷积核大小为3×3,输入通道数为64,输出通道数为128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # 第4个卷积层,卷积核大小为3×3,输入通道数为128,输出通道数为128
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # 第5个卷积层,卷积核大小为3×3,输入通道数为128,输出通道数为256
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 第6个卷积层,卷积核大小为3×3,输入通道数为256,输出通道数为256
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 第7个卷积层,卷积核大小为3×3,输入通道数为256,输出通道数为512
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 第8个卷积层,卷积核大小为3×3,输入通道数为512,输出通道数为512
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 全局池化层,将特征图转换为特征向量
            nn.AdaptiveAvgPool2d(1),  # 输出shape为(m,512,1,1)
            # 两个全连接层,使用卷积实现
            nn.Conv2d(512, 1024, kernel_size=1),  # 输出shape为(m,1024,1,1)
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1))  # 使用卷积层实现将特征向量映射为标量,判断输出图像是否为真实图像,输出shape为(m,1,1,1)
    def forward(self, x):
        batch_size = x.size(0)  # 提取批量值
        return torch.sigmoid(self.net(x).view(batch_size))  # 输出张量展开为1维,sigmoid则将输出结果控制在(0,1)范围内

1.5 损失定义

import torch
from torch import nn
from torchvision.models.vgg import vgg16
import os
os.environ['TORCH_HOME'] = './'  # 将环境变了“TORCH_HOME”设置为当前目录
# 生成器损失定义
class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)  # 加载预训练的VGG模型

                  # 由于vgg.features返回的是包含31个层结构的容器,因此下面语句等价于:

                  # loss_network = vgg.features.eval()

        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False  # 取消所有参数的梯度计算,预加载的网络无需再次训练
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()  # MSE损失
        self.tv_loss = TVLoss()  # TV平滑损失
    def forward(self, out_labels, out_images, target_images):
        # 对抗损失,输入图片全是真实数据时,样本标签为全1,减去网络的预测张量,再求平均值,既为对抗损失
        adversarial_loss = torch.mean(1 - out_labels)  
        # 感知损失,将生成图像和目标图像分别输入loss_network,求两种情况下的输出损失
        perception_loss = self.mse_loss(self.loss_network(out_images),
                                        self.loss_network(target_images))
        # 图像MSE损失,直接求输出图像和目标图像之间的误差
        image_loss = self.mse_loss(out_images, target_images)
        # TV平滑损失,求输出图像的平滑损失
        tv_loss = self.tv_loss(out_images)

                  # 将各损失加权求和,既为生成器的综合损失

        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
# TV平滑损失
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight
    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])

                  # 图像垂直方向相邻像素之间的差异,求平方和

        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()

                  # 图像水平方向相邻像素之间的差异

        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h +
                                          w_tv / count_w) / batch_size
    @staticmethod  # 装饰器,定义类的静态方法,可以通过类直接调用
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]

  # 调试生成器损失

if __name__ == "__main__":
    g_loss = GeneratorLoss()
    print(g_loss)
GeneratorLoss(

  (loss_network): Sequential(

    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (1): ReLU(inplace=True)

    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (3): ReLU(inplace=True)

    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (6): ReLU(inplace=True)

    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (8): ReLU(inplace=True)

    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (11): ReLU(inplace=True)

    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (13): ReLU(inplace=True)

    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (15): ReLU(inplace=True)

    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (18): ReLU(inplace=True)

    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (20): ReLU(inplace=True)

    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (22): ReLU(inplace=True)

    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (25): ReLU(inplace=True)

    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (27): ReLU(inplace=True)

    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    (29): ReLU(inplace=True)

    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

  )

  (mse_loss): MSELoss()

  (tv_loss): TVLoss()

)

生成器损失总共包含4部分,分别是对抗网络损失,逐像素的图像MSE损失,基于VGG模型的感知损失,用于约束图像平滑的TV平滑损失。

2. 模型训练

接下来我们来解读模型的核心训练代码,查看模型训练的结果。训练代码除了模型和损失定义,还需要完成优化器定义,训练和验证指标变量的存储,核心代码如下:

from math import exp  # 求指数
import torch
import torch.nn.functional as F  # 函数式接口模块,调用卷积等函数
from torch.autograd import Variable

  # 一维高斯核函数,返回高斯权重

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()  # 确保高斯权重之和为1,size为[window_size,]
def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)  # 一维高斯核,sigma取默认值1.5,size为[window_size, 1]

          # 对一维高斯核进行转置,然后与一维高斯核通过矩阵乘,得到二维高斯核,size为[1,1,window_size,window_size]

    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())  # _2D_window经过expand之后内存不连续了,contiguous使内存连续,同时完成深拷贝
    return window

  # 计算图像结构相似性指数

def _ssim(img1, img2, window, window_size, channel, size_average=True):

          # img1的加权均值

    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)

          # img2的加权均值

    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

          # 方差公式:Var(X) = E[X**2] - E[X]**2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq

          # 协方差的公式:cov(X,Y) = E[XY] - E[X]*E[Y]

    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

          # C1和C2的作用是保证计算过程中分母不为0

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

          # 结构相似性

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    if size_average:
        return ssim_map.mean()  # 直接对整个tensor求均值
    else:
        return ssim_map.mean(1).mean(1).mean(1)  # 分别对tensor的第二维度求3次均值,如果tensor为4维,则与直接对tensor求均值等同
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size  # 高斯窗口尺寸
        self.size_average = size_average  # 是否直接对结果整体进行平均计算
        self.channel = 1
        self.window = create_window(window_size, self.channel)
    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()  # 获得图像的通道数

                  # 如果图像的通道数为1,且图像的数据类型为float,则使用初始化时创建的窗口,否则使用图像的实际通道创建窗口

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:

                          # 使用图像的通道数创建高斯窗口

            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)  # 将高斯窗口的数据类型设置为与输入图像一致
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)

创建一些文件夹,首次使用时运行!!!

!mkdir training_results
!mkdir epochs
!mkdir statistics
mkdir: cannot create directory ‘training_results’: File exists

mkdir: cannot create directory ‘epochs’: File exists

mkdir: cannot create directory ‘statistics’: File exists

注意:由于阿里云平台 GPU 资源受限,本项目仅使用少量数据集进行训练

import os
from math import log10  # 计算对数
import pandas as pd  # 处理和分析数据
import torch.optim as optim  # 优化器
import torch.utils.data  # 数据预处理
import torchvision.utils as utils  # 图像预处理
from torch.autograd import Variable  # 自动求导
from torch.utils.data import DataLoader  # 分批载入数据
from tqdm import tqdm  # 进度条
if __name__ == '__main__':   # 当前模块是否被直接执行,还是被其它模块导入
    CROP_SIZE = 240 #opt.crop_size   ## 裁剪尺寸,即训练尺度
    UPSCALE_FACTOR = 4  #opt.upscale_factor  ## 超分上采样倍率
    NUM_EPOCHS = 20  #opt.num_epochs  ## 迭代epoch次数 
    ## 获取训练集/验证集
    train_set = TrainDatasetFromFolder('dataset/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('dataset/val', upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False) 
    netG = Generator(UPSCALE_FACTOR) ##生成器定义
    netD = Discriminator() ##判别器定义
    generator_criterion = GeneratorLoss() ##生成器优化目标 
    ## 是否使用GPU
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda() 
    ##生成器和判别器优化器
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters()) 
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}  # PSNR峰值信噪比
    ## epoch迭代
    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)  # 将批量数据生成器传给tqdm,创建一个进度条
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0} ##结果变量
        netG.train() ##生成器训练,允许网络改变权重值,默认即为training状态
        netD.train() ##判别器训练
        ## 每一个epoch的数据迭代,data为lr_image,target为hr_image
        for data, target in train_bar:
            g_update_first = True
            batch_size = data.size(0)  # 批量值
            running_results['batch_sizes'] += batch_size  # 记录每一轮里训练的批量数
            ## 优化判别器,最大化D(x)-1-D(G(z))
            real_img = Variable(target)

                          # 将图像复制到GPU

            if torch.cuda.is_available():
                real_img = real_img.cuda()
            z = Variable(data)
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = netG(z) ## 通过低分辨率图像获取生成结果
            netD.zero_grad()  # 判别器参数梯度置0
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out  # 判别器的损失既误判率,在真实数据上的误判加上在假数据上的误判
            d_loss.backward(retain_graph=True)  # 反向传播,计算梯度
            optimizerD.step() ##优化判别器
            ## 优化生成器 最小化1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            netG.zero_grad()  # 生成器的参数梯度置0
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()  # 反向传播计算梯度 
            # fake_img = netG(z)
            # fake_out = netD(fake_img).mean()
            optimizerG.step()
            # 记录当前损失
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes'])) 
        ## 对验证集进行验证
        netG.eval() ## 设置验证模式,不改变BN层的参数
        out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        if not os.path.exists(out_path):
            os.makedirs(out_path) 
        ## 计算验证集相关指标
        with torch.no_grad():  # 不生成计算图,节省资源
            val_bar = tqdm(val_loader)  # 验证过程进度条
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:  # val_lr低分辨率图像,val_hr_restore通过低分辨率图像恢复后的高分辨率图像,val_hr高分辨率图像
                batch_size = val_lr.size(0)  # 验证集批量大小1
                valing_results['batch_sizes'] += batch_size  # 存储已验证的图片数量
                lr = val_lr ##低分辨率真值图
                hr = val_hr ##高分辨率真值图
                if torch.cuda.is_available():  # 拷贝到GPU
                    lr = lr.cuda() 
                    hr = hr.cuda()
                sr = netG(lr) ##超分重建结果 
                batch_mse = ((sr - hr) ** 2).data.mean() ##计算MSE指标
                valing_results['mse'] += batch_mse * batch_size  # 累计每个批量的mse值
                valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes'])) ##计算PSNR指标
                batch_ssim = ssim(sr, hr).item() ##计算SSIM指标
                valing_results['ssims'] += batch_ssim * batch_size  # 累计每个批量的ssim
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']  # 已验证图像的ssim均值
        ## 存储模型参数
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        ## 记录训练集损失以及验证集的psnr,ssim等指标 \scores\psnr\ssim
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim']) 
        ## 存储结果到本地文件
        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
[1/20] Loss_D: 0.9206 Loss_G: 0.0191 D(x): 0.3966 D(G(z)): 0.2900: 100%|██████████| 44/44 [00:31<00:00,  1.40it/s]

100%|██████████| 85/85 [00:01<00:00, 45.18it/s]

[2/20] Loss_D: 0.9882 Loss_G: 0.0099 D(x): 0.3624 D(G(z)): 0.3444: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.75it/s]

[3/20] Loss_D: 0.9785 Loss_G: 0.0087 D(x): 0.3593 D(G(z)): 0.3436: 100%|██████████| 44/44 [00:31<00:00,  1.40it/s]

100%|██████████| 85/85 [00:01<00:00, 45.57it/s]

[4/20] Loss_D: 1.0011 Loss_G: 0.0070 D(x): 0.5323 D(G(z)): 0.5201: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.72it/s]

[5/20] Loss_D: 0.9912 Loss_G: 0.0071 D(x): 0.3812 D(G(z)): 0.3706: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.98it/s]

[6/20] Loss_D: 0.9738 Loss_G: 0.0072 D(x): 0.4276 D(G(z)): 0.3970: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.68it/s]

[7/20] Loss_D: 1.0016 Loss_G: 0.0066 D(x): 0.1955 D(G(z)): 0.1928: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.08it/s]

[8/20] Loss_D: 1.0028 Loss_G: 0.0064 D(x): 0.1531 D(G(z)): 0.1508: 100%|██████████| 44/44 [00:31<00:00,  1.40it/s]

100%|██████████| 85/85 [00:01<00:00, 45.72it/s]

[9/20] Loss_D: 1.0018 Loss_G: 0.0066 D(x): 0.0594 D(G(z)): 0.0610: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.19it/s]

[10/20] Loss_D: 0.9963 Loss_G: 0.0061 D(x): 0.0795 D(G(z)): 0.0764: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.63it/s]

[11/20] Loss_D: 1.0042 Loss_G: 0.0061 D(x): 0.1649 D(G(z)): 0.1674: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.91it/s]

[12/20] Loss_D: 0.9918 Loss_G: 0.0058 D(x): 0.2955 D(G(z)): 0.2907: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.34it/s]

[13/20] Loss_D: 1.0028 Loss_G: 0.0056 D(x): 0.2586 D(G(z)): 0.2455: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.59it/s]

[14/20] Loss_D: 1.0006 Loss_G: 0.0057 D(x): 0.1642 D(G(z)): 0.1645: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.12it/s]

[15/20] Loss_D: 0.9968 Loss_G: 0.0057 D(x): 0.2240 D(G(z)): 0.2179: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.46it/s]

[16/20] Loss_D: 1.0059 Loss_G: 0.0055 D(x): 0.1927 D(G(z)): 0.1983: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.51it/s]

[17/20] Loss_D: 0.9990 Loss_G: 0.0057 D(x): 0.2181 D(G(z)): 0.2166: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.56it/s]

[18/20] Loss_D: 0.9990 Loss_G: 0.0051 D(x): 0.2205 D(G(z)): 0.2180: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.73it/s]

[19/20] Loss_D: 1.0063 Loss_G: 0.0051 D(x): 0.2115 D(G(z)): 0.2121: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 45.92it/s]

[20/20] Loss_D: 0.9974 Loss_G: 0.0050 D(x): 0.1125 D(G(z)): 0.1072: 100%|██████████| 44/44 [00:31<00:00,  1.41it/s]

100%|██████████| 85/85 [00:01<00:00, 44.92it/s]

从上述代码可以看出,训练时采用的crop_size为240×240,批处理大小为16,使用的优化器为Adam,Adam采用了默认的优化参数。

损失等相关数据将生成在 statistics 文件夹下。

上采样倍率为4的模型训练结果如下:

Image

4倍上采样的PSNR和SSIM曲线

3. 模型测试

接下来我们进行模型的测试。

3.1 测试代码

首先解读测试代码,需要完成模型的载入,图像预处理和结果存储,完整代码如下:

import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
UPSCALE_FACTOR = 4 ##上采样倍率
TEST_MODE = True ## 使用GPU进行测试
IMAGE_NAME = "./dataset/val/10879.jpg"  # 测试图片路径
MODEL_NAME = './epochs/netG_epoch_4_20.pth' ##模型路径
model = Generator(UPSCALE_FACTOR).eval() ##设置验证模式
if TEST_MODE:
    model.cuda()
    model.load_state_dict(torch.load(MODEL_NAME))  # 加载训练好的生成器的权重
else:
    model.load_state_dict(torch.load(MODEL_NAME, map_location=lambda storage, loc: storage))
image = Image.open(IMAGE_NAME) ##读取图片
image = Variable(ToTensor()(image), volatile=True).unsqueeze(0) ##图像预处理
if TEST_MODE:
    image = image.cuda()
with torch.no_grad():
    RESULT_NAME = "out_srf_" + str(UPSCALE_FACTOR) + "_" + IMAGE_NAME.split("/")[-1]
    out = model(image)
    out_img = ToPILImage()(out[0].data.cpu())
    out_img.save(RESULT_NAME)
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:21: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.

预测结果将在本级目录生成,以 out_srf_ 开头

3.2 重建结果

下图展示了若干图片的超分辨结果。

第一行为使用双线性插值进行上采样的结果, 第二行为4倍超分结果,第三行为原始大图。

Image

本次我们对SRGAN模型进行了实践,使用高清人脸数据集进行训练,对低分辨率的人脸图像进行了超分重建,验证了SRGAN模型的有效性,不过该模型仍然有较大的改进空间,它需要使用成对数据集进行训练,而训练时低分辨率图片的模式产生过于简单,无法对复杂的退化类型完成重建。

当要对退化类型更加复杂的图像进行超分辨重建时,模型训练时也应该采取多种对应的数据增强方法,包括但不限于对比度增强,各类噪声污染,JPEG压缩失真等操作,这些就留给读者去做更多的实验。

Logo

权威|前沿|技术|干货|国内首个API全生命周期开发者社区

更多推荐