Writing a project of your own rarely means starting from scratch. Instead, you fill things into an existing skeleton, which can come from someone else's project or be one you build yourself. Below is my own PyTorch project skeleton.

1. Wrapping the variables

    To keep the many variables in a project manageable, wrap them all into one object and read them from that object whenever they are needed; argparse is used for this. The wrapped variables can live in a separate .py file, named arg_parser.py. The code is as follows:

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import argparse

parser = argparse.ArgumentParser(description='one-line description of the project (optional)')

parser.add_argument('--batch_size', type=int, default=16, help='explanation of this argument (optional)')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--snapshots_folder', type=str, default='snapshots/')
parser.add_argument('--sample_output_folder', type=str, default='samples/')

# Arguments referenced later in train.py and main.py; the paths and defaults below are placeholders.
parser.add_argument('--orig_images_path', type=str, default='data/images/')
parser.add_argument('--hazy_images_path', type=str, default='data/hazy/')
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--val_batch_size', type=int, default=8)
parser.add_argument('--weight_decay', type=float, default=0.0001)
parser.add_argument('--grad_clip_norm', type=float, default=0.1)
parser.add_argument('--display_iter', type=int, default=10)
parser.add_argument('--snapshot_iter', type=int, default=200)

args = parser.parse_args()
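
    With this file in place, any other module can simply import args and read values from the same object; the defaults can also be overridden on the command line. A minimal usage sketch (the script name and the --batch_size override below are only examples):

from arg_parser import args

print(args.batch_size)  # 16 by default; e.g. running `python some_script.py --batch_size 32` would print 32
print(args.lr)          # 0.001 by default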

2. Building the training/test sets

    In most cases the training data is our own, which means PyTorch does not ship with it, so we have to create the Dataset ourselves. Implementing a Dataset only requires subclassing it and implementing its two abstract methods, __getitem__(self, item) and __len__(self): the first returns a sample for the given index, the second returns the number of samples. Put the code below in a file named create_dataset.py (the name main.py imports from):

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import os
import glob
import torch
import numpy as np

from PIL import Image
from torch.utils.data import Dataset


def populate_train_list(orig_images_path, haze_images_path):
    data = []

    # Pair each hazy image with its ground-truth image; the ground truth shares the
    # first two '_'-separated fields of the hazy file name.
    image_list_haze = glob.glob(haze_images_path + '*.jpg')
    for image in image_list_haze:
        image = os.path.basename(image)  # works regardless of the path separator
        gt = image.split('_')[0] + '_' + image.split('_')[1] + '.jpg'
        data.append([haze_images_path + image, orig_images_path + gt])

    # First 90%: training set, last 10%: validation set.
    split = int(len(data) * 0.9)
    train_dataset = data[:split]
    test_dataset = data[split:]

    return train_dataset, test_dataset


class dehaze_dataset(Dataset):
    def __init__(self, orig_images_path, haze_images_path, mode='train'):
        super(dehaze_dataset, self).__init__()
        self.mode = mode
        train_dataset, test_dataset = populate_train_list(orig_images_path, haze_images_path)

        if mode == 'train':
            self.data_list = train_dataset
        else:
            self.data_list = test_dataset

    def __getitem__(self, item):
        haze_image_path, gt_path = self.data_list[item]

        haze_image = Image.open(haze_image_path)
        gt = Image.open(gt_path)

        # resize takes (width, height); Image.ANTIALIAS was removed in newer Pillow
        # releases, so use the equivalent Image.LANCZOS filter instead.
        haze_image = haze_image.resize((480, 640), Image.LANCZOS)
        gt = gt.resize((480, 640), Image.LANCZOS)

        haze_image = np.asarray(haze_image) / 255.0
        gt = np.asarray(gt) / 255.0

        haze_image = torch.from_numpy(haze_image).float()
        gt = torch.from_numpy(gt).float()

        return haze_image.permute(2, 0, 1), gt.permute(2, 0, 1)

    def __len__(self):
        return len(self.data_list)
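
    Before moving on, it is worth checking that the Dataset behaves as expected by wrapping it in a DataLoader and inspecting one batch. A minimal sketch, assuming the two folders below actually contain matching clear/hazy .jpg pairs (the paths are placeholders):

from torch.utils.data import DataLoader

from create_dataset import dehaze_dataset

train_set = dehaze_dataset('data/images/', 'data/hazy/', mode='train')
train_loader = DataLoader(dataset=train_set, batch_size=8, shuffle=True)

haze, gt = next(iter(train_loader))
print(haze.shape, gt.shape)  # expected: torch.Size([8, 3, 640, 480]) for both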

3. Network architecture

    Here AODNet is used as an example. Create a file named AODNet.py with the following code:

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F


class AODNet(nn.Module):
    def __init__(self):
        super(AODNet, self).__init__()
        self.conv_1 = nn.Conv2d(3, 3, 1)
        self.conv_2 = nn.Conv2d(3, 3, 3, padding=1)
        self.conv_3 = nn.Conv2d(6, 3, 5, padding=2)
        self.conv_4 = nn.Conv2d(6, 3, 7, padding=3)
        self.conv_5 = nn.Conv2d(12, 3, 3, padding=1)

    def forward(self, x):
        x1 = F.relu(self.conv_1(x))
        x2 = F.relu(self.conv_2(x))
        concat1 = torch.cat((x1, x2), dim=1)
        x3 = F.relu(self.conv_3(concat1))
        concat2 = torch.cat((x2, x3), dim=1)
        x4 = F.relu(self.conv_4(concat2))
        concat3 = torch.cat((x1, x2, x3, x4), dim=1)
        k = F.relu(self.conv_5(concat3))

        # AOD-Net output: J(x) = K(x) * I(x) - K(x) + b, with b = 1
        return F.relu((k * x) - k + 1)

    Network architectures vary endlessly, but reproducing one whose layer parameters are already given is fairly straightforward.
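
    A quick way to confirm the network is wired up correctly is to push a random tensor through it and check that the output shape matches the input. A minimal sketch:

import torch

from AODNet import AODNet

net = AODNet()
dummy = torch.rand(1, 3, 640, 480)  # a fake one-image batch
out = net(dummy)
print(out.shape)  # expected: torch.Size([1, 3, 640, 480])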

4. Training

    Create a file named train.py. The training loop follows a fixed template; once you remember that template, the rest is just filling in the details.

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import torch
import torchvision

from arg_parser import args


def train(net, train_loader, test_loader, optimizer, device):
    criterion = torch.nn.MSELoss().to(device)
    for epoch in range(args.epochs):
        net.train()
        train_loss = 0
        for batch_idx, (haze, gt) in enumerate(train_loader, 0):
            haze, gt = haze.to(device), gt.to(device)

            output = net(haze)
            loss = criterion(output, gt)
            train_loss += loss.item()  # .item() keeps a plain float instead of the whole graph

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip_norm)
            optimizer.step()

            if (batch_idx + 1) % args.display_iter == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(haze), len(train_loader.dataset),
                    100. * (batch_idx + 1) / len(train_loader), train_loss / args.display_iter
                ))
                train_loss = 0

            if (batch_idx + 1) % args.snapshot_iter == 0:
                torch.save(net.state_dict(), args.snapshots_folder + 'Epoch' + str(epoch) + '.pt')

        # Validation: no gradients needed, and all tensors must be on the same device
        # before they are concatenated for saving.
        net.eval()
        with torch.no_grad():
            for batch_idx, (haze, gt) in enumerate(test_loader, 0):
                haze, gt = haze.to(device), gt.to(device)
                output = net(haze)

                torchvision.utils.save_image(torch.cat((haze, output, gt), dim=0),
                                             args.sample_output_folder + str(batch_idx + 1) + '.jpg')

        torch.save(net.state_dict(), args.snapshots_folder + 'epoch_' + str(epoch) + '.pt')

    torch.save(net.state_dict(), args.snapshots_folder + 'dehazer.pt')
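
    Because the loop keeps writing state dicts into args.snapshots_folder, any of those files can later be reloaded, either to resume training or just to sanity-check an intermediate model. A minimal sketch, assuming a snapshot named snapshots/Epoch0.pt already exists:

import torch

from AODNet import AODNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = AODNet().to(device)
# map_location keeps the load working on CPU-only machines
net.load_state_dict(torch.load('snapshots/Epoch0.pt', map_location=device))
net.eval()

with torch.no_grad():
    out = net(torch.rand(1, 3, 640, 480, device=device))
print(out.shape)  # torch.Size([1, 3, 640, 480])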

5. Testing

    The test script (e.g. test.py) loads the trained weights and dehazes a single image:

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import torch
import torchvision
import numpy as np

from PIL import Image
from AODNet import AODNet


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_path = 'D:/...'
    dehaze_net = AODNet().to(device)

    data_hazy = Image.open(image_path)
    data_hazy = (np.asarray(data_hazy) / 255.0)

    data_hazy = torch.from_numpy(data_hazy).float()
    data_hazy = data_hazy.permute(2, 0, 1)
    data_hazy = data_hazy.to(device).unsqueeze(0)

    # map_location keeps the load working on CPU-only machines
    dehaze_net.load_state_dict(torch.load('snapshots/dehazer.pt', map_location=device))
    dehaze_net.eval()

    with torch.no_grad():
        clean_image = dehaze_net(data_hazy)
    torchvision.utils.save_image(torch.cat((data_hazy, clean_image), 0), "results/" + image_path.split("/")[-1])


if __name__ == "__main__":
    main()

6. Main function

    The main script ties everything together: it builds the datasets and data loaders, the network, and the optimizer, then calls train().

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import os
import torch


from torch.utils.data import DataLoader

from train import train
from AODNet import AODNet
from arg_parser import args
from create_dataset import dehaze_dataset


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = dehaze_dataset(args.orig_images_path, args.hazy_images_path)
    test_dataset = dehaze_dataset(args.orig_images_path, args.hazy_images_path, mode='test')

    train_loader = DataLoader(dataset=train_dataset, batch_size=args.train_batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.val_batch_size, shuffle=True)

    net = AODNet().to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    train(net, train_loader, test_loader, optimizer, device)


if __name__ == "__main__":
    if not os.path.exists(args.snapshots_folder):
        os.mkdir(args.snapshots_folder)
    if not os.path.exists(args.sample_output_folder):
        os.mkdir(args.sample_output_folder)

    main()
