EGNet---code(二): github-推理,测试
项目地址 (Repository): https://github.com/JXingZhao/EGNet
测试 (Testing)
For testing:
- Download the pretrained model (2cf5);
  下载预训练模型。
- Change the test image path in dataset.py:
  在 dataset.py 文件中更改测试图片的路径。
class ImageDataTest(data.Dataset):
    """Test-set listing for EGNet evaluation.

    Only ``__init__`` is shown here. It resolves, from ``test_mode`` /
    ``sal_mode``:

    * ``image_root``   -- directory holding the test images,
    * ``image_source`` -- text file listing one image name per line,
    * ``test_fold``    -- result directory (set only for ``test_mode == 1``),

    and reads the (non-blank, whitespace-stripped) names from
    ``image_source`` into ``image_list``.
    """

    # sal_mode code -> dataset directory name used when test_mode == 1.
    _SALIENCY_SETS = {'e': 'ECSSD', 'p': 'PASCALS', 's': 'SOD'}

    def __init__(self, test_mode=1, sal_mode='e', image_root=None, image_source=None):
        """Resolve dataset paths and load the image-name list.

        Args:
            test_mode: 0 -> HED-BSDS test split, 1 -> saliency split chosen
                by ``sal_mode``, 2 -> SK-LARGE test split (names per the
                hard-coded paths).
            sal_mode: for ``test_mode == 1`` only; one of ``'e'`` (ECSSD),
                ``'p'`` (PASCALS) or ``'s'`` (SOD).
            image_root: optional override for the image directory, so the
                path no longer has to be edited in this file.
            image_source: optional override for the image-list file.

        Raises:
            ValueError: if ``test_mode`` or (for ``test_mode == 1``)
                ``sal_mode`` is not one of the supported values. Previously
                this surfaced later as an ``AttributeError`` because
                ``image_source`` was never assigned.
            OSError: if the image-list file cannot be opened.
        """
        if test_mode == 0:
            self.image_root = '/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test/'
            self.image_source = '/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test.lst'
        elif test_mode == 1:
            try:
                set_name = self._SALIENCY_SETS[sal_mode]
            except KeyError:
                raise ValueError(
                    "unknown sal_mode %r for test_mode=1; expected one of %s"
                    % (sal_mode, sorted(self._SALIENCY_SETS))
                ) from None
            self.image_root = '/home/liuj/dataset/saliency_test/%s/Imgs/' % set_name
            self.image_source = '/home/liuj/dataset/saliency_test/%s/test.lst' % set_name
            self.test_fold = '/media/ubuntu/disk/Result/saliency/%s/' % set_name
        elif test_mode == 2:
            self.image_root = '/home/liuj/dataset/SK-LARGE/images/test/'
            self.image_source = '/home/liuj/dataset/SK-LARGE/test.lst'
        else:
            raise ValueError("unknown test_mode %r; expected 0, 1 or 2" % (test_mode,))

        # Caller-supplied paths take precedence over the hard-coded defaults.
        if image_root is not None:
            self.image_root = image_root
        if image_source is not None:
            self.image_source = image_source

        with open(self.image_source, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line in a hand-written
            # test.lst) so they don't become empty image names.
            self.image_list = [entry for entry in (line.strip() for line in f) if entry]
PS: 这里选择第一个模式 e（ECSSD），因为这个数据集只有 1000 张图片，在网上下载即可（我已经上传到本博客的附件里）。
另外，这里有一个 test.lst 文件需要自己编写：
就是把测试图片的文件名逐行写下来，保存为 ECSSD 路径下的 test.lst（即 dataset.py 中 image_source 指向的文件）。生成此文件的代码我也上传了。
- Generate saliency maps for the ECSSD dataset with `python3 run.py --mode test --sal_mode e`,
  the SOD dataset with `python3 run.py --mode test --sal_mode s`,
  the PASCALS dataset with `python3 run.py --mode test --sal_mode p`, and so on.
import argparse
import os
from dataset import get_loader
from solver import Solver
def main(config):
    """Run EGNet training or testing according to ``config.mode``.

    * ``'train'``: build the training loader, ensure the run directory
      ``<save_fold>/run-nnet`` and its ``logs`` / ``models`` subdirectories
      exist, repoint ``config.save_fold`` at that run directory, then call
      ``Solver.train()``.
    * ``'test'``: build the test loader for ``config.test_mode`` /
      ``config.sal_mode`` and call ``Solver.test()``, passing the dataset's
      ``save_folder()`` to the solver.

    Args:
        config: parsed argparse namespace; the attributes read here are
            ``mode``, ``batch_size``, ``num_thread``, ``save_fold``,
            ``test_batch_size``, ``test_mode`` and ``sal_mode``.

    Raises:
        IOError: if ``config.mode`` is neither ``'train'`` nor ``'test'``.
    """
    if config.mode == 'train':
        train_loader, _ = get_loader(config.batch_size, num_thread=config.num_thread)
        run = "nnet"
        run_dir = "%s/run-%s" % (config.save_fold, run)
        # makedirs(exist_ok=True) also (re)creates logs/ and models/ when only
        # the run directory already exists, and tolerates a missing save_fold.
        for sub in ("logs", "models"):
            os.makedirs(os.path.join(run_dir, sub), exist_ok=True)
        config.save_fold = run_dir
        train = Solver(train_loader, None, config)
        train.train()
    elif config.mode == 'test':
        test_loader, dataset = get_loader(
            config.test_batch_size,
            mode='test',
            num_thread=config.num_thread,
            test_mode=config.test_mode,
            sal_mode=config.sal_mode,
        )
        test = Solver(None, test_loader, config, dataset.save_folder())
        test.test(test_mode=config.test_mode)
    else:
        raise IOError("illegal input!!!")
def build_parser():
    """Build the command-line parser for the EGNet run script.

    Declares every option this script reads itself (``main`` and the
    ``__main__`` block below). The Solver may read further attributes from
    the resulting namespace that are not visible here.
    """

    def str2bool(value):
        # argparse's ``type=bool`` turns any non-empty string (even "False")
        # into True; parse the usual spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))

    parser = argparse.ArgumentParser()
    # Training settings (read by main() and the __main__ block)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_thread', type=int, default=4)
    parser.add_argument('--save_fold', type=str, default='./results')
    # Testing settings
    parser.add_argument('--model', type=str, default='./epoch_resnet.pth')
    parser.add_argument('--test_fold', type=str, default='./results/test')
    parser.add_argument('--test_mode', type=int, default=1)
    # ImageDataTest as shown only handles 'e', 'p' and 's'; 'e' (ECSSD) is
    # the split this guide uses.
    parser.add_argument('--sal_mode', type=str, default='e')
    parser.add_argument('--test_batch_size', type=int, default=1)
    # Misc
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--visdom', type=str2bool, default=False)
    return parser


if __name__ == '__main__':
    # Pretrained backbone weights -- change to your own paths.
    # NOTE(review): not read anywhere in this excerpt; presumably meant as
    # defaults for backbone-weight options omitted here -- TODO confirm.
    vgg_path = '/home/liuj/code/Messal/weights/vgg16_20M.pth'
    resnet_path = '/home/liuj/code/Messal/weights/resnet50_caffe.pth'

    config = build_parser().parse_args()
    os.makedirs(config.save_fold, exist_ok=True)
    main(config)
更多推荐
所有评论(0)