目的:运行并粗略看懂ML-GCN的代码。注:此代码为更改后的代码,结构与原来模型相差甚远。但代码结构大致相同。

代码地址https://github.com/chenzhaomin123/ML_GCN

论文地址  https://arxiv.org/abs/1904.03582

目录

一、相关依赖项下载

1.1 程序及数据

1.2 数据集

1.3 放入对应位置

1.4 标注 annotations位置

1.5 环境及依赖项

命令行

二、代码结构

2.1 原始ML-GCN

2.2 可选项

模型结构

loss设置

2.3 模型定义位置


一、相关依赖项下载

1.1 程序及数据

https://github.com/chenzhaomin123/ML_GCN

1.2 数据集

微软发布的 COCO 数据库是一个大型图像数据集, 专为对象检测、分割、人体关键点检测、语义分割和字幕生成而设计。

COCO 数据库的网址是:

运用coco2014数据集,数据集较大

-rw-rw-r-- 1 xingxiangrui xingxiangrui  13G Apr 23 11:20 train2014.zip
-rw-rw-r-- 1 xingxiangrui xingxiangrui 6.2G Apr 23 11:23 val2014.zip

训练集与验证集数量

train2014$ ls -l |grep "^-"|wc -l
82783
val2014$ ls -l |grep "^-"|wc -l
40504

尺寸为640*426

1.3 放入对应位置

调用关系及位置:

in general_trian.py
train_dataset = COCO2014(args.data, phase='train', inp_name=Config.INP_NAME, is_grouping=True)  # fixme
DATA = 'data/data/coco'

in coco.py
    tmpdir = os.path.join(root, 'tmp/')
    data = os.path.join(root, 'data/')
    if not os.path.exists(data):
        os.makedirs(data)
    if not os.path.exists(tmpdir):
        os.makedirs(tmpdir)
    if phase == 'train':
        filename = 'train2014.zip'
    elif phase == 'val':
        filename = 'val2014.zip'
    cached_file = os.path.join(tmpdir, filename)

    # extract file
    img_data = os.path.join(data, filename.split('.')[0])
    if not os.path.exists(img_data):
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=data))
        command = 'unzip {} -d {}'.format(cached_file, data)
        os.system(command)

root= 'data/data/coco'

tmpdir='data/data/coco/tmp/'

data='data/data/coco  /data/'

路径位置应该为这样 /data/data/coco/   data/train2014.zip

解压后图片位置为: /data/data/coco/ data/

(如果不按照这个路径放好数据,程序会重新下载并安装)

1.4 标注 annotations位置

    # train/val images/annotations
    cached_file = os.path.join(tmpdir, 'annotations_trainval2014.zip')
    if not os.path.exists(cached_file):
        print('Downloading: "{}" to {}\n'.format(urls['annotations'], cached_file))
        os.chdir(tmpdir)
        subprocess.Popen('wget ' + urls['annotations'], shell=True)
        os.chdir(root)
    annotations_data = os.path.join(data, 'annotations')
    if not os.path.exists(annotations_data):
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=data))
        command = 'unzip {} -d {}'.format(cached_file, data)
        os.system(command)
    print('[annotation] Done!')

tmpdir= data/data/coco /tmp/

zip 压缩包存放的位置:  cached_file= data/data/coco /tmp/ annotations_trainval2014.zip

解压后文件的位置:  data='data/data/coco  /data/'

敲击命令行,除了网络信息之外,打印出下面,即表明数据集没有问题。

[dataset] Done!
[annotation] Done!
[json] Done!

相应pkl文件也应该放入相应文件夹内(原版代码需要把data前的/去掉,原版有/data这个目录)

    DATA = 'data/data/coco'
    INP_NAME = 'data/data/coco/coco_glove_word2vec.pkl'
    ADJ_FILE = 'data/data/coco/coco_adj.pkl'

1.5 环境及依赖项

没有的话直接 pip install ***

  • numpy
  • torch-0.3.1
  • torchnet
  • torchvision-0.2.0
  • tqdm

命令行

原版ML-GCN

  • lr: learning rate 学习率
  • lrp: factor for learning rate of pretrained layers. The learning rate of the pretrained layers is lr * lrp,预训练层的因子,需要乘以学习率
  • batch-size: number of images per batch
  • image-size: size of the image
  • epochs: number of training epochs
  • evaluate: evaluate model on validation set 在验证集上进行validate,评估模型
  • resume: path to checkpoint 即checkpoint的路径

Demo VOC 2007

python3 demo_voc2007_gcn.py data/voc --image-size 448 --batch-size 32 -e --resume checkpoint/voc/voc_checkpoint.pth.tar

Demo COCO 2014

python3 demo_coco_gcn.py data/coco --image-size 448 --batch-size 32 -e --resume checkpoint/coco/coco_checkpoint.pth.tar

我们的代码:

python general_train.py

(torch031) [xingxiangrui@gzbh-mms-gpu55.gzbh.baidu.com chun-ML_GCN]$ python general_train.py
{'batch_size': 32,
 'data': 'data/data/coco',
 'device_ids': [0, 1, 2, 3],
 'epoch_step': 30,
 'epochs': 100,
 'evaluate': False,
 'image_size': 448,
 'lr': 0.01,
 'lrp': 0.001,
 'momentum': 0.9,
 'print_freq': 10,
 'resume': './checkpoint/coco/exp_4/model_best_79.8707.pth.tar',
 'start_epoch': 0,
 'weight_decay': 1e-06,
 'workers': 4}
[dataset] Done!
[annotation] Done!
[json] Done!
[dataset] Done!
[annotation] Done!
[json] Done!
Number of model parameters: 65196189
<torchvision.transforms.transforms.Compose object at 0x7f05d4c4eeb8>
=> no checkpoint found at './checkpoint/coco/exp_4/model_best_79.8707.pth.tar'
backbone learning rate 0.001
head learning rate 0.01
Epoch: [0][0/2565]	Time 22.885 (22.885)	Data 1.159 (1.159)	Loss 0.7680 (0.7680)
Epoch: [0][10/2565]	Time 1.358 (3.296)	Data 0.000 (0.106)	Loss 0.6504 (0.7201)
Epoch: [0][20/2565]	Time 1.223 (2.334)	Data 0.000 (0.056)	Loss 0.5153 (0.6486)
Epoch: [0][30/2565]	Time 1.185 (1.996)	Data 0.000 (0.038)	Loss 0.4166 (0.5850)
Epoch: [0][40/2565]	Time 1.268 (1.809)	Data 0.000 (0.029)	Loss 0.3573 (0.5346)

用torch0.4.1到后面会报错显存不够,我们需要用torch0.3.1,

python demo_coco_gcn.py data/data/coco --image-size 448 --batch-size 32 --epochs 100 -e --resume checkpoint/coco/checkpoint.pth.tar

二、代码结构

2.1 原始ML-GCN

不同数据集上有不同的代码,我们以coco代码为准

pytorch代码训练过程基本为一个套路

  • 创建参数
  • 创建模型
  • 加载数据
  • 定义loss
  • 定义optimizer
  • train

2.2 可选项

模型结构

三种模型结构,hgat_fc,  hgat_conv,  groupnet(可以理解为baseline)

    # fixme=============begin=========
    if Config.MODEL == 'hgat_fc':
        import mymodels.hgat_fc as hgat_fc
        model = hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                                nclasses_per_group=Config.NCLASSES_PER_GROUP,
                                group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)
    elif Config.MODEL == 'hgat_conv':
        import mymodels.hgat_conv as hgat_conv
        model = hgat_conv.HGAT_CONV(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                            nclasses_per_group=Config.NCLASSES_PER_GROUP,
                            group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)
    elif Config.MODEL == 'groupnet':
        pass
    else:
        raise Exception()
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

loss设置

可以在这三种loss之中选择一种

BCEWithLogitsLoss, MultiLabelSoftMarginLoss, DeepMarLoss
    if Config.LOSS_TYPE == 'MultiLabelSoftMarginLoss':
        criterion = nn.MultiLabelSoftMarginLoss()
    elif Config.LOSS_TYPE == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss()
    elif Config.LOSS_TYPE == 'DeepMarLoss':
        criterion = F.binary_cross_entropy_with_logits
    else:
        raise Exception()

2.3 模型定义位置

以MODEL = 'hgat_fc'  为准,在mymodels中hgat_fc.py之中。

        import mymodels.hgat_fc as hgat_fc
        model = hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                                nclasses_per_group=Config.NCLASSES_PER_GROUP,
                                group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)

其中:

class HGAT_FC(nn.Module):
    def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):
        super(HGAT_FC, self).__init__()

三、自动断开及其解决

3.1 问题描述

Test: [510/1255]	Time 0.501 (0.559)	Data 0.000 (0.002)	Loss 0.0975 (0.1163)
Test: [520/1255]	Time 0.535 (0.559)	Data 0.000 (0.002)	Loss 0.1211 (0.1163)
packet_write_wait: Connection to 10.44.67.42 port 22: Broken pipe

https://blog.csdn.net/weixin_36474809/article/details/88710505

已经用这个方法设置了非自动断开,但是运行代码时候会自动断开,可能因为运行时间过长。因此我们需要重新设置代码。

3.2 运用nohup指令运行

直接更改general_train.sh文件

CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python -u demo_coco_hgat.py  > ./train_logs/exp_4.log 2>&1 &
  • -u参数的使用

python命令加上-u(unbuffered)参数后会强制其标准输出也同标准错误一样不通过缓存直接打印到屏幕。

CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python -u general_train.py

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐