[Tensorflow] MNIST源码分析
本文分析 TensorFlow 官方 MNIST 教程的源码, 主要涉及 mnist.py 中的 inference()/loss()/training() 以及 fully_connected_feed.py 中的 run_training()。
源代码地址: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/mnist
mnist.py
主要功能: 构建一个完全连接(fully connected)的MNIST模型
推理 inference()
def inference(images, hidden1_units, hidden2_units):
    """Build the MNIST graph up to the point where it can be used for inference.

    Args:
        images: placeholder tensor holding a batch of input images.
        hidden1_units: size of the first hidden layer.
        hidden2_units: size of the second hidden layer.

    Returns:
        logits: tensor with the (unnormalized) class scores.

    The network is a pair of fully connected ReLU layers followed by a
    linear layer with NUM_CLASSES output units producing the logits.
    """
    # Hidden layer 1.
    with tf.name_scope('hidden1'):
        # Each variable receives its initializer at construction time.
        # truncated_normal draws from a clipped normal distribution; the
        # 1/sqrt(fan_in) stddev keeps early activations well scaled.
        # Shape is [units connected from, units connected to]; under the
        # name scope the variable's full name is "hidden1/weights".
        w1 = tf.Variable(
            tf.truncated_normal(
                [IMAGE_PIXELS, hidden1_units],
                stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
            name='weights')
        # Biases start at zero, one per output unit of this layer.
        b1 = tf.Variable(tf.zeros([hidden1_units]), name='biases')
        # ReLU activation on the affine transform.
        hidden1 = tf.nn.relu(tf.matmul(images, w1) + b1)
    # Hidden layer 2 — same pattern, fed by hidden1.
    with tf.name_scope('hidden2'):
        w2 = tf.Variable(
            tf.truncated_normal(
                [hidden1_units, hidden2_units],
                stddev=1.0 / math.sqrt(float(hidden1_units))),
            name='weights')
        b2 = tf.Variable(tf.zeros([hidden2_units]), name='biases')
        hidden2 = tf.nn.relu(tf.matmul(hidden1, w2) + b2)
    # Output layer: linear (no activation); softmax is applied later by
    # the loss, so this returns raw logits.
    with tf.name_scope('softmax_linear'):
        w3 = tf.Variable(
            tf.truncated_normal(
                [hidden2_units, NUM_CLASSES],
                stddev=1.0 / math.sqrt(float(hidden2_units))),
            name='weights')
        b3 = tf.Variable(tf.zeros([NUM_CLASSES]), name='biases')
        logits = tf.matmul(hidden2, w3) + b3
    return logits
损失 loss()
def loss(logits, labels):
    """Compute the loss between logits and labels.

    Args:
        logits: logits tensor from inference(), float - [batch_size, NUM_CLASSES].
        labels: labels tensor, int32/int64 - [batch_size].

    Returns:
        A scalar tensor holding the mean cross-entropy loss over the batch.
    """
    # tf.to_int64 is deprecated; tf.cast(..., tf.int64) is the supported
    # equivalent and behaves identically here.
    labels = tf.cast(labels, tf.int64)
    # sparse_* variant takes integer class ids directly (no one-hot needed)
    # and applies softmax internally.
    return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
训练 training()
def training(loss, learning_rate):
    """Create the op that minimizes the loss via gradient descent.

    Args:
        loss: scalar loss tensor from loss().
        learning_rate: learning rate for gradient descent.

    Returns:
        train_op: the op that must be run to perform one training step.
    """
    # Record the loss in the summaries so TensorBoard can plot it.
    tf.summary.scalar('loss', loss)
    # Plain SGD optimizer at the given learning rate.
    sgd = tf.train.GradientDescentOptimizer(learning_rate)
    # Non-trainable counter that tracks how many steps have run; minimize()
    # increments it automatically on every step.
    step_counter = tf.Variable(0, name='global_step', trainable=False)
    # The returned op applies the gradients, minimizing the loss.
    train_op = sgd.minimize(loss, global_step=step_counter)
    return train_op
fully_connected_feed.py
主要功能: 利用下载的数据集训练构建好的MNIST模型, 以数据反馈字典(feed dictionary)的形式作为输入
run_training()
- 这是核心function()
- 功能: 通过一系列步骤训练MNIST
def run_training():
    """Train MNIST for a number of steps, feeding data via feed dictionaries."""
    # Download (if needed) and unpack the data, returning DataSet instances
    # for train / validation / test.
    data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
    # Build the whole model on the default graph.
    with tf.Graph().as_default():
        # Placeholders for a batch of images and labels.
        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)
        # Inference op: produces logits from the images.
        logits = mnist.inference(images_placeholder,
                                 FLAGS.hidden1,
                                 FLAGS.hidden2)
        # Loss op.
        loss = mnist.loss(logits, labels_placeholder)
        # Training op.
        train_op = mnist.training(loss, FLAGS.learning_rate)
        # Op that compares the logits to the labels during evaluation.
        eval_correct = mnist.evaluation(logits, labels_placeholder)
        # Merge all summaries registered on the graph into one tensor.
        summary = tf.summary.merge_all()
        # Variable-initialization op.
        init = tf.global_variables_initializer()
        # Saver for writing training checkpoints.
        saver = tf.train.Saver()
        # Create the session that runs the ops.
        sess = tf.Session()
        # SummaryWriter outputs summaries and the graph for TensorBoard.
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
        # After everything is built: run the variable initializers.
        sess.run(init)
        # Main training loop.
        # NOTE(review): xrange is Python-2 only; presumably this file imports
        # it from six.moves as the upstream tutorial does — confirm.
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            # Build the feed dictionary for this step: fill_feed_dict()
            # asks the given DataSet for the next batch_size images and
            # labels and maps them onto the matching placeholders; the
            # dict is then passed to sess.run() as the training input.
            feed_dict = fill_feed_dict(data_sets.train,
                                       images_placeholder,
                                       labels_placeholder)
            # Run one step of the model. The return values are the
            # activations from `train_op` (discarded) and the `loss` op.
            # To inspect other ops or variables, add them to the list
            # passed to sess.run() and they come back in the tuple.
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict=feed_dict)
            duration = time.time() - start_time
            # Periodically report the current status.
            if step % 100 == 0:
                # Print step, loss, and wall time for the step.
                print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
                # Update the events file so TensorBoard stays current.
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()
            # Save a checkpoint and evaluate the model periodically
            # (and on the final step).
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
                # Evaluate against the training set.
                print('Training Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        data_sets.train)
                # Evaluate against the validation set.
                print('Validation Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        data_sets.validation)
                # Evaluate against the test set.
                print('Test Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        data_sets.test)
更多推荐
已为社区贡献1条内容
所有评论(0)