迁移学习的实现需要网络在其他数据集上做预训练,完成参数调优工作,然后拿预训练好的参数在新的任务上做fine-tune,但是有时候可能只需要预训练的网络的一部分权重,本文主要提供一个方法如何在tf上加载想要加载的权重。

在使用tensorflow加载网络权重的时候,直接使用tf.train.Saver().restore(sess, ‘ckpt’)的话是直接加载了全部权重,我们可能只需要加载网络的前几层权重,或者只要或者不要特定几层的权重,这时可以使用下面的方法:

var = tf.global_variables()
var_to_restore = [val  for val in var if 'conv1' in val.name or 'conv2'in val.name]
saver = tf.train.Saver(var_to_restore )
saver.restore(sess, os.path.join(model_dir, model_name))
var_to_init = [val  for val in var if 'conv1' not in val.name or 'conv2'not in val.name]
tf.initialize_variables(var_to_init)

这样就只从ckpt文件里只读取到了两层卷积的卷积参数,前提是你的前两层网络结构和名字和ckpt文件里定义的一样。将var_to_restore和var_to_init反过来就是加载名字中不包含conv1、2的权重。

tensorflow
一个面向所有人的开源机器学习框架

如果使用tensorflow的slim选择性读取权重的话就更方便了

exclude = ['layer1', 'layer2']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, os.path.join(model_dir, model_name))

这样就完成了不读取ckpt文件中’layer1’, ‘layer2’权重

推荐内容
GitHub 加速计划 / te / tensorflow
25
4
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
4f64a3d5 Instead, check for this case in `ResolveUsers` and `ResolveOperand`, by querying whether the `fused_expression_root` is part of the `HloFusionAdaptor`. This prevents us from stepping into nested fusions. PiperOrigin-RevId: 724311958 2 个月前
aa7e952e Fix a bug in handling negative strides, and add a test case that exposes it. We can have negative strides that are not just -1, e.g. with a combining reshape. PiperOrigin-RevId: 724293790 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐