模型存储的参数恢复

模型存取介绍

见该博客

模型恢复

使用代码

def get_restorer():

    checkpoint_path = tf.train.latest_checkpoint(os.path.join(FLAGS.trained_checkpoint, cfgs.VERSION))

    if checkpoint_path != None:
        if RESTORE_FROM_RPN:
            print('___restore from rpn___')
            model_variables = slim.get_model_variables()
            restore_variables = [var for var in model_variables if not var.name.startswith('Fast_Rcnn')] + [slim.get_or_create_global_step()]
            for var in restore_variables:
                print(var.name)
            restorer = tf.train.Saver(restore_variables)
        else:
            restorer = tf.train.Saver()
        print("model restore from :", checkpoint_path)
    else:
        checkpoint_path = FLAGS.pretrained_model_path
        print("model restore from pretrained mode, path is :", checkpoint_path)

        model_variables = slim.get_model_variables()

        restore_variables = [var for var in model_variables if
                             (var.name.startswith(cfgs.NET_NAME)
                              and not var.name.startswith('{}/logits'.format(cfgs.NET_NAME)))]
        for var in restore_variables:
            print(var.name)
        restorer = tf.train.Saver(restore_variables)
    return restorer, checkpoint_path

需在FLAGS.pretrained_model_path给出预训练模型参数的路径。

.ckpt文件与{.ckpt.meta, .ckpt.index, .ckpt.data-00000-of-00001}的区别

两者没有区别,只是在设置路径时(FLAGS.pretrained_model_path),对于.ckpt文件,路径需写至文件夹路径/.ckpt;而对于{.ckpt.meta, .ckpt.index, .ckpt.data-00000-of-00001},路径需写至文件夹路径/.ckpt-160186,其中.ckpt-160186.meta的前缀。

Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐