For more troubleshooting notes, see:
https://blog.csdn.net/qiankendeNMY/article/details/128450196

Paper: https://arxiv.org/abs/2003.06957
Code: https://github.com/ucbdrive/few-shot-object-detection

My setup:
Python: 3.8.16 (Ubuntu 20.04)
PyTorch: 1.7.1
CUDA: 11.0
GPU: RTX 3090 Ti (24 GB)

1. Environment setup

1.1 Install PyTorch

Install a PyTorch build that matches your CUDA version. The PyTorch version must also match the prebuilt Detectron2 wheel you install later:

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

1.2 Build Detectron2

I chose version v0.3:
https://github.com/facebookresearch/detectron2/releases/tag/v0.3

First clone the repository and check out the matching version:

git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
git checkout v0.3


Then install detectron2==0.3 from the prebuilt wheels:

python -m pip install detectron2==0.3 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html
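The -f URL in this command points at a wheel index that encodes both the CUDA build (cu110) and the PyTorch major.minor version (torch1.7), which is why the PyTorch install in 1.1 has to match it. The Python sketch below derives that URL from a torch.__version__ string; the helper name detectron2_wheel_index is my own, and it assumes the URL pattern seen above also holds for other CUDA/PyTorch combinations (not every combination actually has prebuilt wheels).

```python
import re

# Base URL of the prebuilt Detectron2 wheels, as used in the command above.
WHEEL_ROOT = "https://dl.fbaipublicfiles.com/detectron2/wheels"


def detectron2_wheel_index(torch_version: str) -> str:
    """Build the wheel index URL from a torch.__version__ string like "1.7.1+cu110"."""
    match = re.fullmatch(r"(\d+)\.(\d+)(?:\.\d+)?\+(cu\d+|cpu)", torch_version)
    if match is None:
        raise ValueError(f"no CUDA/CPU tag in torch version {torch_version!r}")
    major, minor, build_tag = match.groups()
    # Only the CUDA tag and torch major.minor appear in the URL; the patch
    # version (".1" in "1.7.1") does not.
    return f"{WHEEL_ROOT}/{build_tag}/torch{major}.{minor}/index.html"


if __name__ == "__main__":
    # With the PyTorch build from section 1.1 this prints the URL used above.
    print(detectron2_wheel_index("1.7.1+cu110"))
```

In practice you would pass torch.__version__ instead of the literal string. A version string without a +cuXXX or +cpu suffix raises ValueError, because the CUDA build cannot be inferred from it.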

Install the remaining dependencies listed in requirements.txt (run this from the TFA repository root):

python3 -m pip install -r requirements.txt
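Once everything is installed, it is worth confirming that the environment matches the setup listed at the top. This sketch uses only the standard library (importlib.metadata, available from Python 3.8) to report installed versions without importing the heavy packages; the function names are my own. It also checks the NumPy version: the np.str alias was deprecated in NumPy 1.20 and removed in 1.24, so with NumPy 1.24 or newer you should expect the AttributeError described in the training sections below.

```python
import platform
from importlib import metadata  # Python 3.8+

# Distribution names to check; adjust the list if your environment differs.
PACKAGES = ("torch", "torchvision", "detectron2", "numpy")


def installed_versions(packages=PACKAGES):
    """Return {name: version or None}, plus the running Python version."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


def numpy_has_np_str(numpy_version):
    """True if this NumPy still has the np.str alias (removed in 1.24)."""
    major, minor = (int(part) for part in numpy_version.split(".")[:2])
    return (major, minor) < (1, 24)


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:12s} {version or 'NOT INSTALLED'}")
```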

2. Prepare the datasets

Download the few-shot dataset splits:

wget -r -R index.html dl.yf.io/fs-det/datasets/

Then move the downloaded vocsplit folder into the datasets folder.
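A minimal Python sketch of this move, assuming wget kept its default recursive layout (a dl.yf.io/fs-det/datasets/ tree under the current directory) and that you run it from the TFA repository root; adjust both paths if your download landed elsewhere. The helper name move_vocsplit is hypothetical, and the whole thing is equivalent to a single mv of the vocsplit folder.

```python
import shutil
from pathlib import Path


def move_vocsplit(download_root=Path("dl.yf.io/fs-det/datasets"),
                  datasets_dir=Path("datasets")):
    """Move <download_root>/vocsplit to <datasets_dir>/vocsplit.

    Assumed layout: wget -r saves files under a directory named after the
    host, and the script runs from the TFA repository root.
    """
    src = Path(download_root) / "vocsplit"
    dst = Path(datasets_dir) / "vocsplit"
    if not src.is_dir():
        raise FileNotFoundError(f"{src} not found; check where wget saved the files")
    if dst.exists():
        raise FileExistsError(f"{dst} already exists; remove it first")
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(str(src), str(dst))
    return dst
```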

If you use your own dataset, you can also generate vocsplit yourself with:

python3 -m datasets.prepare_voc_few_shot --seeds 1 30

3. Base training

Base training on split1:

python3 -m tools.train_net --num-gpus 1 --config-file configs/PascalVOC-detection/split1/faster_rcnn_R_101_FPN_base1.yaml

Error encountered:

Traceback (most recent call last):
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/test/code/TFA/tools/train_net.py", line 113, in <module>
    launch(
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/engine/launch.py", line 62, in launch
    main_func(*args)
  File "/home/test/code/TFA/tools/train_net.py", line 105, in main
    trainer = Trainer(cfg)
  File "/home/test/code/TFA/fsdet/engine/defaults.py", line 304, in __init__
    data_loader = self.build_train_loader(cfg)
  File "/home/test/code/TFA/fsdet/engine/defaults.py", line 492, in build_train_loader
    return build_detection_train_loader(cfg)
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/config/config.py", line 201, in wrapped
    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/config/config.py", line 238, in _get_args_from_config
    ret = from_config_func(*args, **kwargs)
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/build.py", line 308, in _train_loader_from_config
    dataset = get_detection_dataset_dicts(
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/build.py", line 227, in get_detection_dataset_dicts
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/build.py", line 227, in <listcomp>
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/catalog.py", line 58, in get
    return f()
  File "/home/test/code/TFA/fsdet/data/meta_pascal_voc.py", line 147, in <lambda>
    lambda: load_filtered_voc_instances(
  File "/home/test/code/TFA/fsdet/data/meta_pascal_voc.py", line 48, in load_filtered_voc_instances
    fileids = np.loadtxt(f, dtype=np.str)
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/numpy/__init__.py", line 305, in __getattr__
    raise AttributeError(__former_attrs__[attr])
AttributeError: module 'numpy' has no attribute 'str'.
`np.str` was a deprecated alias for the builtin `str`. To avoid this error in existing code, use `str` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.str_` here.
The aliases was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
    https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations

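Solution:
Open the line that raised the error (fsdet/data/meta_pascal_voc.py, line 48) and change np.str to np.str_. As the error message says, the builtin str works as well.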
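The same alias appears more than once in fsdet/data/meta_pascal_voc.py (the tracebacks in this post hit both line 48 and line 37), so instead of fixing one line per crash you can patch every occurrence at once. A minimal sketch, assuming it is run from the TFA repository root and that fsdet/ is where the alias is used; the function name patch_np_str is hypothetical. It rewrites files in place, so commit or back up first.

```python
import re
from pathlib import Path

# Matches the removed alias "np.str" but not "np.str_" or "np.string_",
# because \b requires a non-word character right after "str".
NP_STR = re.compile(r"\bnp\.str\b")


def patch_np_str(root="fsdet"):
    """Replace np.str with np.str_ in every .py file under root.

    Returns the list of files that were changed.
    """
    changed = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8")
        new_text = NP_STR.sub("np.str_", text)
        if new_text != text:
            path.write_text(new_text, encoding="utf-8")
            changed.append(path)
    return changed


if __name__ == "__main__":
    for path in patch_np_str():
        print("patched", path)
```

Running it a second time changes nothing, since np.str_ no longer matches the pattern.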
Evaluation:

python3 -m tools.test_net --num-gpus 1 --config-file configs/PascalVOC-detection/split1/faster_rcnn_R_101_FPN_base1.yaml --eval-only


4. Randomly initialize the weights for the novel classes

python3 -m tools.ckpt_surgery --src1 checkpoints/voc/faster_rcnn/faster_rcnn_R_101_FPN_base1/model_final.pth --method randinit --save-dir checkpoints/voc/faster_rcnn/faster_rcnn_R_101_FPN_all1
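Conceptually, the base model's box predictor only has outputs for the 15 base classes (plus background), while fine-tuning needs outputs for all 20 VOC classes. The randinit surgery copies the learned base-class weights into a larger layer, randomly initializes the rows for the 5 novel classes, and saves the result as model_reset_surgery.pth, the file the fine-tuning commands load. The NumPy sketch below only illustrates that idea on plain arrays. It is not the code in tools/ckpt_surgery.py, which works on the PyTorch checkpoint, also handles bias terms and maps class ids explicitly; the function name, the standard deviations and the "base classes come first" ordering are assumptions.

```python
import numpy as np

NUM_BASE, NUM_ALL = 15, 20  # PASCAL VOC split: 15 base + 5 novel classes


def randinit_surgery(cls_w, bbox_w, rng=None):
    """Simplified 'randinit' surgery on the last-layer weights.

    cls_w:  (NUM_BASE + 1, D) classification weights; the last row is background.
    bbox_w: (4 * NUM_BASE, D) class-specific box-regression weights.
    Assumes the base classes occupy the first NUM_BASE indices of the
    all-class ordering (a simplification of the real class-id mapping).
    """
    rng = np.random.default_rng() if rng is None else rng
    dim = cls_w.shape[1]
    # Every row starts as small random values; base rows are overwritten below,
    # so only the novel-class rows keep their random initialization.
    new_cls = rng.normal(0.0, 0.01, size=(NUM_ALL + 1, dim))
    new_bbox = rng.normal(0.0, 0.001, size=(4 * NUM_ALL, dim))
    new_cls[:NUM_BASE] = cls_w[:NUM_BASE]  # base-class classifier rows
    new_cls[NUM_ALL] = cls_w[NUM_BASE]     # background row moves to the new last index
    new_bbox[: 4 * NUM_BASE] = bbox_w      # base-class box-regression rows
    return new_cls, new_bbox
```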

5. Fine-tuning

1) Using CosineSimOutputLayers

python3 -m tools.train_net --num-gpus 1 --config-file configs/PascalVOC-detection/split1/faster_rcnn_R_101_FPN_ft_all1_1shot.yaml --opts MODEL.WEIGHTS checkpoints/voc/faster_rcnn/faster_rcnn_R_101_FPN_all1/model_reset_surgery.pth

2) Using FastRCNNConvFCHead

python3 -m tools.train_net --num-gpus 1 --config-file configs/PascalVOC-detection/split1/faster_rcnn_R_101_FPN_ft_fc_all1_1shot.yaml --opts MODEL.WEIGHTS checkpoints/voc/faster_rcnn/faster_rcnn_R_101_FPN_all1/model_reset_surgery.pth
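The two commands differ only in the config file. As far as I understand TFA, the key difference is the classifier in the box predictor: CosineSimOutputLayers scores a box by the cosine similarity between its feature and each class weight, multiplied by a scale factor, while the fc config uses an ordinary linear layer. The NumPy sketch below contrasts the two scoring rules; the function names are hypothetical, and the default scale of 20 is what I believe MODEL.ROI_HEADS.COSINE_SCALE defaults to in TFA, so check your config.

```python
import numpy as np


def linear_logits(x, w, b):
    """Standard FC classifier: logits = x @ w.T + b."""
    return x @ w.T + b


def cosine_logits(x, w, scale=20.0, eps=1e-5):
    """Cosine-similarity classifier: scale * cos(x_i, w_j).

    Features and class weights are L2-normalized, so each logit depends only
    on the angle between a feature and a class weight, times a fixed scale.
    """
    x_norm = x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)
    w_norm = w / (np.linalg.norm(w, axis=1, keepdims=True) + eps)
    return scale * (x_norm @ w_norm.T)
```

Because of the normalization, cosine logits stay within plus or minus scale and barely change when a feature vector is rescaled, whereas linear logits grow with the feature magnitude.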

Error encountered:

/home/test/code/TFA/fsdet/data/meta_pascal_voc.py:37: FutureWarning: In the future `np.str` will be defined as the corresponding NumPy scalar.
  fileids_ = np.loadtxt(f, dtype=np.str).tolist()
Traceback (most recent call last):
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/test/code/TFA/tools/train_net.py", line 113, in <module>
    launch(
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/engine/launch.py", line 62, in launch
    main_func(*args)
  File "/home/test/code/TFA/tools/train_net.py", line 105, in main
    trainer = Trainer(cfg)
  File "/home/test/code/TFA/fsdet/engine/defaults.py", line 304, in __init__
    data_loader = self.build_train_loader(cfg)
  File "/home/test/code/TFA/fsdet/engine/defaults.py", line 492, in build_train_loader
    return build_detection_train_loader(cfg)
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/config/config.py", line 201, in wrapped
    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/config/config.py", line 236, in _get_args_from_config
    ret = from_config_func(*args, **kwargs)
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/build.py", line 301, in _train_loader_from_config
    dataset = get_detection_dataset_dicts(
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/build.py", line 220, in get_detection_dataset_dicts
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/build.py", line 220, in <listcomp>
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/detectron2/data/catalog.py", line 58, in get
    return f()
  File "/home/test/code/TFA/fsdet/data/meta_pascal_voc.py", line 147, in <lambda>
    lambda: load_filtered_voc_instances(
  File "/home/test/code/TFA/fsdet/data/meta_pascal_voc.py", line 37, in load_filtered_voc_instances
    fileids_ = np.loadtxt(f, dtype=np.str).tolist()
  File "/home/test/anaconda3/envs/tfa/lib/python3.8/site-packages/numpy/__init__.py", line 305, in __getattr__
    raise AttributeError(__former_attrs__[attr])
AttributeError: module 'numpy' has no attribute 'str'.
`np.str` was a deprecated alias for the builtin `str`. To avoid this error in existing code, use `str` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.str_` here.
The aliases was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
    https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations

Solution:
As before, open the failing line (fsdet/data/meta_pascal_voc.py, line 37) and change np.str to np.str_.
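The patched line reads a few-shot split file that lists one image path per line. For reference, here is a standalone sketch of that loading step with the alias fixed (dtype=str behaves the same as np.str_ here). It mirrors what I understand the loader around line 37 does, including a guard for single-line files such as 1-shot splits; the function name load_file_ids and the sample paths are illustrative assumptions.

```python
import io

import numpy as np


def load_file_ids(f):
    """Read one image path per line and return the image ids (file stems)."""
    # dtype=str (or np.str_) replaces the removed np.str alias.
    fileids = np.loadtxt(f, dtype=str).tolist()
    # With a single line, loadtxt returns a 0-d array and tolist() gives a
    # plain str instead of a list, so wrap it.
    if isinstance(fileids, str):
        fileids = [fileids]
    return [fid.split("/")[-1].split(".jpg")[0] for fid in fileids]


if __name__ == "__main__":
    sample = io.StringIO("datasets/VOC2007/JPEGImages/000005.jpg\n")
    print(load_file_ids(sample))
```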
Fine-tuning finished. Results:
1) CosineSimOutputLayers results

2) FastRCNNConvFCHead results
