Tensorflow查看网络、冻结变量和迁移训练

(Inspect network structure, freeze graph variables, and finetune/transfer learning in Tensorflow)

 

1.    查看网络结构和参数

 

python
/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py
--file_name=model.ckpt-1562770
--tensor_name=unit_1_2/sub1/conv1/DW

 

源码中的inspect_checkpoint.py可以看ckpt文件中的层和某层的权重值

如果只有--file_name就只显示层,如果还有--tensor_name就能显示那一层的权重

 

2.    只训练graph中部分变量(相当于冻结了其他变量)

Tensorflow在构建graph的过程中会默认自动收集一些变量名到对应的Collection。例如TRAINABLE_VARIABLES就是所有可训练的变量集合。

因此可以通过使用tf.get_collection,指定TRAINABLE_VARIABLES,使其仅包含我们需要重新训练的变量,来冻结其他变量的训练。

例子如下:

 

    first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
        "unit_last")
    trainable_variables = first_train_vars
    #print trainable_variables
    grads = self.optimizer.compute_gradients(self.cost, self.trainable_variables)

 

 

 

 

3.    更改graph后恢复训练

根据monitored_session.py,使用MonitoredTrainingSession来开启控制Session的时候,若指定的checkpoint路径中有上次的存档,则现有源码只能严格按照之前训练恢复。因此我们需要一个空的checkpoint路径,此时MonitoredTrainingSession就会执行init_op以及init_fn。在init_fn中自己添加恢复函数,并把init_fn作为参数加入MonitoredTrainingSession中的scaffold即可。

例子如下:

 

    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=['logit'])
    init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
        ckpt.model_checkpoint_path, variables_to_restore)
    def InitAssignFn(scaffold, sess):
        sess.run(init_assign_op, init_feed_dict)
    scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)

 

 

 

 

 

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐